From 77488cfe630be4390593ee16d6ca0d24dc67f93f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 08:38:55 -0700 Subject: [PATCH 1/4] commit --- src/forge/controller/provisioner.py | 3 +- src/forge/env_constants.py | 1 + src/forge/observability/metric_actors.py | 9 +- src/forge/observability/metrics.py | 37 ++- tests/unit_tests/observability/conftest.py | 95 +++++++ .../unit_tests/observability/test_metrics.py | 231 ++++++++++++++++++ 6 files changed, 370 insertions(+), 6 deletions(-) create mode 100644 tests/unit_tests/observability/conftest.py create mode 100644 tests/unit_tests/observability/test_metrics.py diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index c823afb29..429c5760f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -262,7 +262,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh - # Spawn local logging actor on each process and register with global logger + # Spawn local fetcher actor on each process and register with global logger + # Can be disabled by FORGE_DISABLE_METRICS env var _ = await get_or_create_metric_logger(procs) return procs diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index 3adcdfc41..a4e024d83 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -14,4 +14,5 @@ METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" # Makes forge.observability.metrics.record_metric a no-op +# and disables spawning LocalFetcherActor in get_or_create_metric_logger FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS" diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d67a66a83..edd1f24d8 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -6,10 +6,12 @@ import asyncio import logging +import os from typing import Any, Dict, Optional from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( get_logger_backend_class, LoggerBackend, @@ -95,8 +97,11 @@ async def get_or_create_metric_logger( f"Both should be True (already setup) or both False (needs setup)." ) - # Setup local_fetcher_actor if needed - if not proc_has_local_fetcher: + # Setup local_fetcher_actor if needed (unless disabled by environment flag) + if ( + not proc_has_local_fetcher + and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true" + ): local_fetcher_actor = proc.spawn( "local_fetcher_actor", LocalFetcherActor, global_logger ) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 990a301e0..4d527ec1b 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -437,7 +437,21 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - raise ValueError("Collector not initialized—call init first") + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg=( + "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + " This happens when you try to use `record_metric` before calling `init_backends`." + " To disable this warning, please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "or set env variable `FORGE_DISABLE_METRICS=True`" + ), + ) + return if key not in self.accumulators: self.accumulators[key] = reduction.accumulator_class(reduction) @@ -458,8 +472,16 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - logger.debug( - f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + "\nPlease call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "before calling `flush`", ) return {} @@ -662,6 +684,15 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): raise ValueError( f"Shared ID required but not provided for {self.name} backend init" ) + + # Clear any stale service tokens that might be pointing to dead processes + # In multiprocessing environments, WandB service tokens can become stale and point + # to dead service processes. This causes wandb.init() to hang indefinitely trying + # to connect to non-existent services. Clearing forces fresh service connection. + from wandb.sdk.lib.service import service_token + + service_token.clear_service_in_env() + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) self.run = wandb.init( id=shared_id, diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py new file mode 100644 index 000000000..a803c252d --- /dev/null +++ b/tests/unit_tests/observability/conftest.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared fixtures and mocks for observability unit tests.""" + +from unittest.mock import MagicMock, patch + +import pytest +from forge.observability.metrics import LoggerBackend, MetricCollector + + +class MockBackend(LoggerBackend): + """Mock backend for testing metrics logging without external dependencies.""" + + def __init__(self, logger_backend_config=None): + super().__init__(logger_backend_config or {}) + self.logged_metrics = [] + self.init_called = False + self.finish_called = False + self.metadata = {} + + async def init(self, role="local", primary_logger_metadata=None): + self.init_called = True + self.role = role + self.primary_logger_metadata = primary_logger_metadata or {} + + async def log(self, metrics, step): + self.logged_metrics.append((metrics, step)) + + async def finish(self): + self.finish_called = True + + def get_metadata_for_secondary_ranks(self): + return self.metadata + + +@pytest.fixture(autouse=True) +def clear_metric_collector_singletons(): + """Clear MetricCollector singletons before each test to avoid state leakage.""" + MetricCollector._instances.clear() + yield + MetricCollector._instances.clear() + + +@pytest.fixture(autouse=True) +def clean_metrics_environment(): + """Override the global mock_metrics_globally fixture to allow real metrics testing.""" + import os + + from forge.env_constants import FORGE_DISABLE_METRICS + + # Set default state for tests (metrics enabled) + if FORGE_DISABLE_METRICS in os.environ: + del os.environ[FORGE_DISABLE_METRICS] + + yield + + +@pytest.fixture +def mock_rank(): + """Mock current_rank function with configurable rank.""" + with patch("forge.observability.metrics.current_rank") as mock: + rank_obj = MagicMock() + rank_obj.rank = 0 + mock.return_value = rank_obj + yield mock + + +@pytest.fixture +def mock_actor_context(): + """Mock Monarch actor context for testing actor name generation.""" + with patch("forge.observability.metrics.context") as mock_context, patch( + "forge.observability.metrics.current_rank" + ) as mock_rank: + + # Setup mock context + ctx = MagicMock() + actor_instance = MagicMock() + actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]" + ctx.actor_instance = actor_instance + mock_context.return_value = ctx + + # Setup mock rank + rank_obj = MagicMock() + rank_obj.rank = 0 + mock_rank.return_value = rank_obj + + yield { + "context": mock_context, + "rank": mock_rank, + "expected_name": "TestActor_0XQr_r0", + } diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py new file mode 100644 index 000000000..3e864bdf7 --- /dev/null +++ b/tests/unit_tests/observability/test_metrics.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import ( + ConsoleBackend, + get_logger_backend_class, + MeanAccumulator, + MetricCollector, + record_metric, + Reduce, + WandbBackend, +) + + +class TestCriticalFixes: + """Test critical production fixes from Diff 1.""" + + def test_uninitialized_push_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.push() logs warning when uninitialized.""" + collector = MetricCollector() + + # Should not raise error, just log warning and return + collector.push("test", 1.0, Reduce.MEAN) + assert any( + "Metric logging backends" in record.message for record in caplog.records + ) + + @pytest.mark.asyncio + async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.flush() logs warning when uninitialized.""" + collector = MetricCollector() + + # Should not raise error, just log warning and return empty dict + result = await collector.flush(step=1, return_state=True) + assert result == {} + assert any( + "Cannot flush collected metrics" in record.message + for record in caplog.records + ) + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "true"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_disabled(self, mock_collector_class): + """Test record_metric is no-op when FORGE_DISABLE_METRICS=true.""" + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_not_called() + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "false"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): + """Test record_metric works when FORGE_DISABLE_METRICS=false.""" + mock_collector = MagicMock() + mock_collector_class.return_value = mock_collector + + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_called_once() + mock_collector.push.assert_called_once() + + @patch("forge.observability.metrics.get_actor_name_with_rank") + def test_wandb_backend_creation(self, mock_actor_name): + """Test WandbBackend creation and basic setup without WandB dependency.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + config = { + "project": "test_project", + "group": "test_group", + "reduce_across_ranks": True, + } + backend = WandbBackend(config) + + assert backend.project == "test_project" + assert backend.group == "test_group" + assert backend.reduce_across_ranks is True + assert backend.share_run_id is False # default + + # Test metadata method + metadata = backend.get_metadata_for_secondary_ranks() + assert metadata == {} # Should be empty when no run + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_console_backend(self, mock_actor_name): + """Test ConsoleBackend basic operations.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + backend = ConsoleBackend({}) + + await backend.init(role="local") + + # Test log - should not raise + await backend.log({"test": 1.0}, step=1) + + await backend.finish() # Should not raise + + +class TestBasicAccumulators: + """Test basic accumulator functionality.""" + + def test_mean_accumulator(self): + """Test MeanAccumulator operations.""" + acc = MeanAccumulator(Reduce.MEAN) + + # Test initial state + assert acc.get_value() == 0.0 + state = acc.get_state() + assert state["sum"] == 0.0 + assert state["count"] == 0 + + # Test append and get_value + acc.append(10.0) + acc.append(20.0) + assert acc.get_value() == 15.0 + + # Test state + state = acc.get_state() + assert state["sum"] == 30.0 + assert state["count"] == 2 + assert state["reduction_type"] == "mean" + + # Test reset + acc.reset() + assert acc.get_value() == 0.0 + assert acc.get_state()["sum"] == 0.0 + assert acc.get_state()["count"] == 0 + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + + +class TestBackendFactory: + """Test backend factory function.""" + + def test_backend_factory(self): + """Test get_logger_backend_class factory function.""" + assert get_logger_backend_class("console") == ConsoleBackend + assert get_logger_backend_class("wandb") == WandbBackend + + with pytest.raises(ValueError, match="Unknown logger backend type"): + get_logger_backend_class("invalid_backend") + + +class TestMetricCollector: + """Test MetricCollector singleton behavior.""" + + def test_singleton_per_rank(self, mock_rank): + """Test MetricCollector singleton behavior per rank.""" + mock_rank.return_value.rank = 0 + collector1 = MetricCollector() + collector2 = MetricCollector() + assert collector1 is collector2 + + # Different rank should get different instance + mock_rank.return_value.rank = 1 + collector3 = MetricCollector() + assert collector1 is not collector3 + + +class TestMetricActorDisabling: + """Test environment flag to disable metric actors.""" + + async def _test_fetcher_registration(self, env_var_value, should_register_fetchers): + """Check if FORGE_DISABLE_METRICS=[True, False, None] correctly disables fetcher registration. + + Args: + env_var_value: Value to set for FORGE_DISABLE_METRICS (None means unset) + should_register_fetchers: Whether fetchers should be registered (True) or not (False) + """ + import os + + import forge.observability.metric_actors + from forge.env_constants import FORGE_DISABLE_METRICS + from monarch.actor import this_host + + # set fresh env + # Note: Environment variable setup is handled by clean_metrics_environment fixture + forge.observability.metric_actors._global_logger = None + + if env_var_value is not None: + os.environ[FORGE_DISABLE_METRICS] = env_var_value + + procs = this_host().spawn_procs(per_host={"cpus": 1}) + + if hasattr(procs, "_local_fetcher"): + delattr(procs, "_local_fetcher") + + # Test functionality + global_logger = await get_or_create_metric_logger(proc_mesh=procs) + + # Get results to check + proc_has_fetcher = hasattr(procs, "_local_fetcher") + global_has_fetcher = await global_logger.has_fetcher.call_one(procs) + + # Assert based on expected behavior + if should_register_fetchers: + assert ( + proc_has_fetcher + ), f"Expected process to have _local_fetcher when {env_var_value=}" + assert ( + global_has_fetcher + ), f"Expected global logger to have fetcher registered when {env_var_value=}" + else: + assert ( + not proc_has_fetcher + ), f"Expected process to NOT have _local_fetcher when {env_var_value=}" + assert ( + not global_has_fetcher + ), f"Expected global logger to NOT have fetcher registered when {env_var_value=}" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_value,should_register", + [ + ("false", True), + ("true", False), + (None, True), + ], + ) + async def test_fetcher_registration_with_env_flag(self, env_value, should_register): + """Test fetcher registration behavior with different environment flag values.""" + await self._test_fetcher_registration(env_value, should_register) From 8a24e715c15bb3d18337a75fc491a8a81e605291 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:17:08 -0700 Subject: [PATCH 2/4] update where we check FORGE_DISABLE_METRICS --- src/forge/controller/provisioner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 429c5760f..cf712079b 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -20,6 +20,7 @@ from forge.controller.launcher import BaseLauncher, get_launcher +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -263,8 +264,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh # Spawn local fetcher actor on each process and register with global logger - # Can be disabled by FORGE_DISABLE_METRICS env var - _ = await get_or_create_metric_logger(procs) + if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + _ = await get_or_create_metric_logger(procs) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): From 3f3bc51bd69316cd403c874a72b7e9824ae9f190 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:17:48 -0700 Subject: [PATCH 3/4] remove protected import --- src/forge/observability/metrics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4d527ec1b..64843f110 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,6 +13,8 @@ from monarch.actor import context, current_rank +from forge.util.logging import log_once + logger = logging.getLogger(__name__) @@ -437,8 +439,6 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, @@ -472,8 +472,6 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, From 4fe26116d9562826f0fcc4cc37bbce48c40ccb18 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:23:18 -0700 Subject: [PATCH 4/4] protect import --- src/forge/controller/provisioner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cf712079b..5ca331f32 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,7 +21,6 @@ from forge.controller.launcher import BaseLauncher, get_launcher from forge.env_constants import FORGE_DISABLE_METRICS -from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -265,6 +264,8 @@ def bootstrap(env: dict[str, str]): # Spawn local fetcher actor on each process and register with global logger if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + from forge.observability.metric_actors import get_or_create_metric_logger + _ = await get_or_create_metric_logger(procs) return procs @@ -286,6 +287,10 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async with self._lock: # Deregister local logger from global logger if hasattr(proc_mesh, "_local_fetcher"): + from forge.observability.metric_actors import ( + get_or_create_metric_logger, + ) + global_logger = await get_or_create_metric_logger(proc_mesh) await global_logger.deregister_fetcher.call_one(proc_mesh)