diff --git a/.meta/mast/qwen3_14b_mast.yaml b/.meta/mast/qwen3_14b_mast.yaml index 786f0103c..9560db4e1 100644 --- a/.meta/mast/qwen3_14b_mast.yaml +++ b/.meta/mast/qwen3_14b_mast.yaml @@ -19,9 +19,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + logging_mode: global_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/.meta/mast/qwen3_1_7b_mast.yaml b/.meta/mast/qwen3_1_7b_mast.yaml index 4065cf07a..604fc4f4e 100644 --- a/.meta/mast/qwen3_1_7b_mast.yaml +++ b/.meta/mast/qwen3_1_7b_mast.yaml @@ -19,9 +19,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + logging_mode: global_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/.meta/mast/qwen3_32b_mast.yaml b/.meta/mast/qwen3_32b_mast.yaml index 713c1f784..b9079a2c2 100644 --- a/.meta/mast/qwen3_32b_mast.yaml +++ b/.meta/mast/qwen3_32b_mast.yaml @@ -19,9 +19,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + logging_mode: global_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/.meta/mast/qwen3_4b_mast.yaml b/.meta/mast/qwen3_4b_mast.yaml index e11e2a25a..5e7442c12 100644 --- a/.meta/mast/qwen3_4b_mast.yaml +++ b/.meta/mast/qwen3_4b_mast.yaml @@ -19,9 +19,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + logging_mode: global_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/.meta/mast/qwen3_8b_mast.yaml b/.meta/mast/qwen3_8b_mast.yaml index 0405d767f..ec90db0ff 100644 --- a/.meta/mast/qwen3_8b_mast.yaml +++ b/.meta/mast/qwen3_8b_mast.yaml @@ -19,9 +19,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + logging_mode: global_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 85872681f..ef522e57b 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -304,7 +304,7 @@ async def main(cfg: DictConfig): else: provisioner = await init_provisioner() - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + metric_logging_cfg = cfg.get("metric_logging", {}) mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index cc67952fd..6bb2ebab3 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -16,11 +16,11 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 27c71b3db..67d9e3a77 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -19,11 +19,11 @@ rollout_threads: 32 # make 
this 4x the number of policy replicas seems to work w # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index e19b751d3..683aa1503 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -12,11 +12,11 @@ off_by_n: 1 # Off by one by default # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 8efd3dace..555aa761e 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,7 +12,9 @@ from .metrics import ( BackendRole, ConsoleBackend, + get_logger_backend_class, LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -43,6 +45,7 @@ "BackendRole", # Enums "Reduce", + "LoggingMode", # Utility functions "get_proc_name_with_rank", # Actor classes diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index ee6fe6277..e1f0a65df 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -10,7 +10,6 @@ from typing import Any, Union from monarch.actor import ( - Actor, context, endpoint, get_or_spawn_controller, @@ -18,11 +17,14 @@ this_proc, ) +from forge.controller.actor import ForgeActor + from forge.env import FORGE_DISABLE_METRICS from forge.observability.metrics import ( BackendRole, get_logger_backend_class, LoggerBackend, + LoggingMode, MetricCollector, reduce_metrics_states, ) @@ -63,9 +65,9 @@ async def get_or_create_metric_logger( mlogger = await get_or_create_metric_logger(process_name="Controller") # Initialize logging backends - await mlogger.init_backends.call_one({ - "console": {"reduce_across_ranks": True}, - "wandb": {"project": "my_project", "reduce_across_ranks": False} + await mlogger.init_backends({ + "console": {"logging_mode": "global_reduce"}, + "wandb": {"project": "my_project", "logging_mode": "per_rank_reduce"} }) # Initialize services... @@ -126,7 +128,7 @@ async def get_or_create_metric_logger( return global_logger -class LocalFetcherActor(Actor): +class LocalFetcherActor(ForgeActor): """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh and accesses each rank's local MetricCollector. @@ -141,7 +143,6 @@ def __init__( ) -> None: self.global_logger = global_logger self.process_name = process_name - _is_initialized = False @endpoint async def flush( @@ -165,20 +166,20 @@ async def flush( @endpoint async def init_backends( self, - metadata_per_primary_backend: dict[str, dict[str, Any]], + metadata_per_controller_backend: dict[str, dict[str, Any]], config: dict[str, Any], global_step: int = 0, ) -> None: """Init per-rank logger backends and MetricCollector. Args: - metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. 
+ metadata_per_controller_backend (dict[str, dict[str, Any]]): Metadata from controller backends for shared state. config (dict[str, Any]): Backend configurations with logging modes and settings. global_step (int): Initial step for metrics. """ collector = MetricCollector() await collector.init_backends( - metadata_per_primary_backend, + metadata_per_controller_backend, config, global_step, process_name=self.process_name, @@ -190,15 +191,16 @@ async def shutdown(self) -> None: await collector.shutdown() -class GlobalLoggingActor(Actor): +class GlobalLoggingActor(ForgeActor): """Coordinates metric logging across all ProcMeshes and their ranks. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), with per-rank and/or global reduction logging modes. - If a backend config has flag `reduce_across_ranks=False`, an instance of the backend - is initialized per-rank, otherwise it is done once globally. - + This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor + is automatically spawned per-procmesh in `forge.controller.provisioner.py` and registered + with this actor. The LocalFetcherActor is responsible for instantiating + the per-rank MetricCollector and working as a bridge between GlobalLoggingActor and processes. Flow: GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger @@ -208,57 +210,118 @@ def __init__(self): self.fetchers: dict[str, LocalFetcherActor] = {} self.config: dict[str, Any] | None = None self.global_logger_backends: dict[str, LoggerBackend] = {} - self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {} + self.metadata_per_controller_backend: dict[str, dict[str, Any]] = {} + + def _validate_backend_config( + self, backend_name: str, config: dict[str, Any] + ) -> dict[str, Any]: + """Validate and normalize backend configuration.""" + if "logging_mode" not in config: + raise ValueError( + f"logging_mode is required for backend '{backend_name}' but was not provided. " + f"Please specify a logging_mode in your config. " + f"See forge.observability.metrics.LoggingMode for available options: " + f"{', '.join([mode.value for mode in LoggingMode])}." + ) - @endpoint - async def init_backends(self, config: dict[str, Any]) -> None: - """ - Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors - in all registered fetchers. + mode_str = config["logging_mode"] + mode = LoggingMode(mode_str) + + # Validate per_rank_share_run configuration + share_run = config.get("per_rank_share_run", False) + if mode == LoggingMode.GLOBAL_REDUCE and share_run: + logger.warning( + f"{backend_name}: per_rank_share_run=True is ignored in {mode.value} mode. " + "Setting it to False." + ) + share_run = False + + # WandB-specific warning for suboptimal configuration + if ( + backend_name == "wandb" + and mode == LoggingMode.PER_RANK_REDUCE + and share_run + ): + logger.warning( + "WandB: Using 'per_rank_reduce' with 'per_rank_share_run=True' is not recommended. " + "This configuration can lead to confusing metrics where reduced values from multiple ranks " + "are written to the same run/step, displaying only one of them. Consider either:\n" + " 1. Set 'per_rank_share_run=False' to create separate runs per rank, OR\n" + " 2. 
Use 'per_rank_no_reduce' for real-time streaming to a shared run" ) -        A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source -        for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. +        return { +            **config, +            "logging_mode": mode, +            "per_rank_share_run": share_run, +        } -        The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, -        a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, -        and will log independently, i.e. each rank will have its own run in wandb. +    @endpoint +    async def init_backends(self, config: dict[str, Any]) -> None: +        """Sets the config in the global actor and initializes backends and collectors for already-registered fetchers. +        Actors spawned later are initialized in the `register_fetcher` endpoint. -        Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states -        and reduce them to a single value, which will be logged by the primary backend in this controller. +        Controller backends (instantiated in the controller) can provide metadata to be shared with rank backends, +        e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`. Args: -            config (dict[str, Any]): Config for metric logging where keys are backend names, -                e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} +            config (dict[str, Any]): Config for metric logging where keys are backend names. +                Each backend config supports: +                - logging_mode (str | LoggingMode): Required. See LoggingMode for available options. +                - per_rank_share_run (bool, default False): For per-rank modes only. Whether ranks +                    share a single run/logger instance. Ignored for "global_reduce" mode. +                - Additional backend-specific options (e.g., "project" for WandB) + +        Example: +            { +                "console": {"logging_mode": "global_reduce"}, +                "wandb": { +                    "logging_mode": "per_rank_no_reduce", +                    "per_rank_share_run": True, +                    "project": "my_project", +                } +            } + +        Raises: +            ValueError: If backend config is invalid or missing required fields.
""" - self.config = config + self.config = {} + # Skip initialization if disabled by environment flag if FORGE_DISABLE_METRICS.get_value(): return + # Validate and normalize each backend config for backend_name, backend_config in config.items(): - backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role=BackendRole.GLOBAL, name="global_reduce") - - # Extract metadata from primary logger to be shared with secondary loggers - # and store it - reduce_across_ranks = backend_config.get("reduce_across_ranks", True) - if not reduce_across_ranks: - primary_backend_metadata = ( + self.config[backend_name] = self._validate_backend_config( + backend_name, backend_config + ) + + # Initialize backends based on logging mode + for backend_name, backend_config in self.config.items(): + mode = backend_config["logging_mode"] + + backend: LoggerBackend = get_logger_backend_class(backend_name)( + backend_config + ) + await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce") + + # Extract metadata from controller logger to be shared with per-rank loggers + if mode != LoggingMode.GLOBAL_REDUCE: + controller_metadata: dict[str, Any] = ( backend.get_metadata_for_secondary_ranks() or {} ) - self.metadata_per_primary_backend[ - backend_name - ] = primary_backend_metadata + self.metadata_per_controller_backend[backend_name] = controller_metadata - # Store global logger backends - if reduce_across_ranks: + # Store global logger backends for later flush + if mode == LoggingMode.GLOBAL_REDUCE: self.global_logger_backends[backend_name] = backend - # Eager init collectors on all registered fetchers in parallel, passing primary states and config + # Init collectors on all registered fetchers if self.fetchers: tasks = [ fetcher.init_backends.call( - self.metadata_per_primary_backend, self.config + self.metadata_per_controller_backend, self.config ) for fetcher in self.fetchers.values() ] @@ -278,7 +341,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> No if self.config: logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}") await fetcher.init_backends.call( - self.metadata_per_primary_backend, self.config + self.metadata_per_controller_backend, self.config ) @endpoint @@ -306,46 +369,52 @@ async def flush(self, global_step: int) -> None: config = self.config if config is None: logger.warning( - "GlobalLoggingActor flush() called before init_backends(). " - "No backends will be flushed." + "Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()." + " No backends will be flushed. 
Please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" ) return - # if reduce_across_ranks=True, we need to reduce the states from all ranks - # and log with the primary backend - requires_reduce = any( - backend_config.get("reduce_across_ranks", True) + + # Check if need to collect states from fetchers for global reduction + needs_state_collection = any( + backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE for backend_config in config.values() ) logger.debug( - f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + f"Global flush for global step {global_step}: {len(self.fetchers)} fetchers" ) # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.flush.call(global_step, return_state=requires_reduce) + f.flush.call(global_step, return_state=needs_state_collection) for f in self.fetchers.values() ], return_exceptions=True, ) - if requires_reduce: - # Handle exceptions and extract values from ValueMesh results - all_local_states = [] - for result in results: - if isinstance(result, BaseException): - logger.warning(f"Flush failed on a fetcher: {result}") - continue - - # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] - for gpu_info, local_metric_state in result.items(): - if isinstance(local_metric_state, dict): - all_local_states.append(local_metric_state) - else: - logger.warning( - f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" - ) + if needs_state_collection: + + def extract_values_from_valuemesh(results) -> list[dict[str, Any]]: + all_local_states = [] + for result in results: + if isinstance(result, BaseException): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. 
{gpu_info=}, {local_metric_state=}" +                        ) +                return all_local_states + +            all_local_states = extract_values_from_valuemesh(results) if not all_local_states: logger.warning(f"No states to reduce for global_step {global_step}") @@ -354,12 +423,9 @@ async def flush(self, global_step: int) -> None: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) -            # Log to each global logger_backend -            for ( -                logger_backend_name, -                logger_backend, -            ) in self.global_logger_backends.items(): -                await logger_backend.log(reduced_metrics, global_step) +            # Log to global backends +            for backend_name, backend in self.global_logger_backends.items(): +                await backend.log_batch(reduced_metrics, global_step) @endpoint def has_fetcher(self, proc_id: str) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 980bb89fc..55a3c31a2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,13 +13,13 @@ from typing import Any import pytz -from monarch.actor import context, current_rank +from monarch.actor import current_rank from forge.observability.utils import get_proc_name_with_rank -from forge.util.logging import log_once +from forge.util.logging import get_logger, log_once -logger = logging.getLogger(__name__) +logger = get_logger("INFO") class BackendRole(Enum): @@ -33,6 +33,35 @@ class BackendRole(Enum): GLOBAL = "global" +class LoggingMode(Enum): +    """Metric logging behavior for distributed training scenarios. + +    Each mode serves different observability needs: + +    GLOBAL_REDUCE = "global_reduce" +        Best for: Metrics that are best visualized as a single value per step. +        Behavior: All ranks accumulate → controller reduces → single log entry +        Example use: 8 ranks training, want 1 loss value per training step averaged across all ranks +        Where: GlobalLoggingActor logs reduced values to backends on flush. + +    PER_RANK_REDUCE = "per_rank_reduce" +        Best for: Per-rank performance metrics, debugging individual rank behavior +        Behavior: Each rank accumulates + logs its own reduced values +        Example use: Monitor GPU utilization per rank, get 8 separate log entries per step +        Where: MetricCollector on each rank logs reduced values to backends on flush. + +    PER_RANK_NO_REDUCE = "per_rank_no_reduce" +        Best for: Real-time streaming, time-series debugging +        Behavior: Raw values logged immediately on record_metric() calls. Ignores reduce type. +        Example use: See what every rank is doing in real time. +        Where: MetricCollector on each rank logs raw values to backends on push. +    """ + +    GLOBAL_REDUCE = "global_reduce" +    PER_RANK_REDUCE = "per_rank_reduce" +    PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -56,7 +85,7 @@ def accumulator_class(self): class Metric: """Container for metric data including key, value, reduction type, and timestamp. -    Timestamp is automatically set to current EST time if not provided. +    Timestamp is automatically set to current UTC time if not provided. """ key: str @@ -70,55 +99,6 @@ def __post_init__(self): self.timestamp = datetime.now(pytz.UTC).timestamp() -def get_actor_name_with_rank() -> str: -    """ -    Extracts actor information from Monarch context to form a logging name. - -    Returns: -        str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). -            Falls back to "UnknownActor" if context unavailable.
- """ - # Add more defensive checks - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" - - actor_instance = ctx.actor_instance - rank = current_rank() - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" - - return rank_name - - def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: """Thin wrapper to send metrics to per-rank local MetricCollectors. @@ -126,14 +106,10 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None call `record_metric` anywhere in the code without moving the collector from function to function. - The collector methods are triggered per-rank by a - `forge.observability.metric_actors.LocalFetcherActor`, instantiated - during actor initialization. - - Records are flushed when `forge.observability.metric_actors.GlobalLoggingActor.flush()` - is called, typically triggered by the training loop at regular intervals. - Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`. 
+ + Collected metrics are flushed to backends on flush(), generally: + GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger """ # Skip metrics collection if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true": @@ -216,6 +192,7 @@ class MetricAccumulator(ABC): def __init__(self, reduction: Reduce) -> None: self.reduction_type = reduction + self.is_reset = True @abstractmethod def append(self, value: Any) -> None: @@ -249,9 +226,11 @@ def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.count = 0 + self.is_reset = True def append(self, value: Any) -> None: v = float(value.item() if hasattr(value, "item") else value) + self.is_reset = False self.sum += v self.count += 1 @@ -272,6 +251,7 @@ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return total_sum / total_count if total_count > 0 else 0.0 def reset(self) -> None: + self.is_reset = True self.sum = 0.0 self.count = 0 @@ -280,9 +260,11 @@ class SumAccumulator(MetricAccumulator): def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.total = 0.0 + self.is_reset = True def append(self, value: Any) -> None: v = float(value.item() if hasattr(value, "item") else value) + self.is_reset = False self.total += v def get_value(self) -> float: @@ -296,6 +278,7 @@ def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return sum(s["total"] for s in states) def reset(self) -> None: + self.is_reset = True self.total = 0.0 @@ -303,22 +286,28 @@ class MaxAccumulator(MetricAccumulator): def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.max_val = float("-inf") + self.is_reset = True def append(self, value: Any) -> None: v = float(value.item() if hasattr(value, "item") else value) + self.is_reset = False self.max_val = max(self.max_val, v) def get_value(self) -> float: return self.max_val def get_state(self) -> dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} + return { + "reduction_type": self.reduction_type.value, + "max_val": self.max_val, + } @classmethod def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return max(s["max_val"] for s in states) def reset(self) -> None: + self.is_reset = True self.max_val = float("-inf") @@ -326,22 +315,28 @@ class MinAccumulator(MetricAccumulator): def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.min_val = float("inf") + self.is_reset = True def append(self, value: Any) -> None: v = float(value.item() if hasattr(value, "item") else value) + self.is_reset = False self.min_val = min(self.min_val, v) def get_value(self) -> float: return self.min_val def get_state(self) -> dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} + return { + "reduction_type": self.reduction_type.value, + "min_val": self.min_val, + } @classmethod def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return min(s["min_val"] for s in states) def reset(self) -> None: + self.is_reset = True self.min_val = float("inf") @@ -351,9 +346,11 @@ def __init__(self, reduction: Reduce) -> None: self.sum = 0.0 self.sum_sq = 0.0 self.count = 0 + self.is_reset = True def append(self, value: Any) -> None: v = float(value.item() if hasattr(value, "item") else value) + self.is_reset = False self.sum += v self.sum_sq += v * v self.count += 1 @@ -389,6 +386,7 @@ def 
get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return max(0.0, variance) ** 0.5 def reset(self) -> None: + self.is_reset = True self.sum = 0.0 self.sum_sq = 0.0 self.count = 0 @@ -402,22 +400,23 @@ def reset(self) -> None: class MetricCollector: """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. - A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, - in the GlobalLoggingActor. + Supports multiple logging backends, each with different logging modes. + For options, check `forge.observability.metrics.LoggerBackend` and `forge.observability.metrics.LoggingMode`. - - Ensures one instance per process; actors call record_metric() which delegates here. + Behavior: + - Ensures one instance per rank; + - Using `record_metric()` delegates here; - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also - return non-reduced states for global aggregation. This can be different for each backend. - - Resets accumulators post-flush to avoid leaks across train steps; + return non-reduced states for global aggregation. + - Resets accumulators post-flush to avoid leaks across steps; """ _instances: dict[int, "MetricCollector"] = {} _singleton_rank: int def __new__(cls): - """Singleton per-rank, ensures one instance per process.""" + """Singleton per-rank, ensures one instance per rank.""" rank = current_rank().rank if rank not in cls._instances: @@ -426,6 +425,7 @@ def __new__(cls): inst._singleton_rank = rank else: inst = cls._instances[rank] + # Defensive check for bugs in singleton implementation - should never fail in normal operation if inst._singleton_rank != rank: raise ValueError( f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" @@ -438,70 +438,103 @@ def __init__(self) -> None: self.accumulators: dict[str, MetricAccumulator] = {} self.rank = current_rank().rank - self.logger_backends: list[LoggerBackend] = [] + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + self.global_step: int = 0 # Set on `init_backends` and updated on `flush` self._is_initialized = False self.proc_name_with_rank: str | None = None async def init_backends( self, - metadata_per_primary_backend: dict[str, dict[str, Any]] | None, + metadata_per_controller_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], global_step: int = 0, process_name: str | None = None, ) -> None: - """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated - once globally. + """Initialize per-rank logger backends and MetricCollector state. + + A logger backend is represented by a backend class (e.g. WandBBackend, ConsoleBackend). + Backends are categorized by their logging_mode. For details, see `forge.observability.metrics.LoggingMode`. Args: - metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary - logger backend, e.g., {"wandb": {"run_id": "abc123"}}. - config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. - global_step (int, default 0): Initial step for metrics. 
+ metadata_per_controller_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from controller + for backends that require shared state across processes, e.g., + {"wandb": {"shared_run_id": "abc123"}}. + config (Dict[str, Any]): Backend configurations where each key is a backend name + and value contains logging_mode and backend-specific settings. + e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}} + global_step (int, default 0): Initial step for logging. Can be used when + resuming from a checkpoint. process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: - logger.debug(f"Rank {self.rank}: MetricCollector already initialized") + logger.debug( + f"{self.proc_name_with_rank}: MetricCollector already initialized" + ) return self.global_step = global_step self.proc_name_with_rank = get_proc_name_with_rank(process_name) - # instantiate local backends if any + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + + # Initialize backends based on logging mode for backend_name, backend_config in config.items(): - if backend_config.get("reduce_across_ranks", True): - continue # Skip local backend instantiation and use global instead + mode = backend_config["logging_mode"] + + # sanity check + if not isinstance(mode, LoggingMode): + raise TypeError( + f"Expected LoggingMode enum for {backend_name}.logging_mode, got {type(mode)}: {mode}." + ) + + # We should never hit this. Backend will be instantiated in GlobalLoggingActor. + if mode == LoggingMode.GLOBAL_REDUCE: + logger.debug("Skipping local instantiation for GLOBAL_REDUCE.") + continue - # get metadata from primary backend if any - primary_metadata = {} - if metadata_per_primary_backend: - primary_metadata = metadata_per_primary_backend.get(backend_name, {}) + # get metadata from controller backend, if any + controller_metadata = {} + if metadata_per_controller_backend: + controller_metadata = metadata_per_controller_backend.get( + backend_name, {} + ) # instantiate local backend - logger_backend = get_logger_backend_class(backend_name)(backend_config) - await logger_backend.init( + backend: LoggerBackend = get_logger_backend_class(backend_name)( + backend_config + ) + await backend.init( role=BackendRole.LOCAL, - primary_logger_metadata=primary_metadata, - name=self.proc_name_with_rank, + controller_logger_metadata=controller_metadata, + process_name=self.proc_name_with_rank, ) - self.logger_backends.append(logger_backend) + + # Categorize by logging mode + if mode == LoggingMode.PER_RANK_NO_REDUCE: + self.per_rank_no_reduce_backends.append(backend) + else: + self.per_rank_reduce_backends.append(backend) self._is_initialized = True def push(self, metric: Metric) -> None: """Process a metric according to configured logging modes. - Args: - metric: Metric dataclass containing key, value, reduction type, and timestamp. + Behavior depends on backend modes: + - PER_RANK_NO_REDUCE: Stream metric immediately to backends + - PER_RANK_REDUCE/GLOBAL_REDUCE: Accumulate for per step batch logging - Raises: - TypeError: If metric is not a Metric object. 
+ Args: + metric (Metric): Metric dataclass Example: collector = MetricCollector() metric = Metric("loss", 0.5, Reduce.MEAN) - collector.push(metric) + collector.push(metric) # Streams immediately if no_reduce, else accumulates """ + # sanity check if not self._is_initialized: log_once( logger, @@ -520,7 +553,13 @@ def push(self, metric: Metric) -> None: # Validate metric object if not isinstance(metric, Metric): - raise TypeError(f"Expected {Metric} object, got {type(metric)}") + raise TypeError( + f"Expected {Metric} object, got {metric} of type {type(metric)}" + ) + + # For PER_RANK_NO_REDUCE backends: stream without reduce + for backend in self.per_rank_no_reduce_backends: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -540,7 +579,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state}, + dict[str, dict[str, Any]]: Dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: @@ -550,7 +589,7 @@ async def flush( msg=f"Cannot flush collected metrics for {get_proc_name_with_rank()}. " " MetricCollector.flush() called before init_backends()." "\nPlease call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "before calling `flush`", ) @@ -565,30 +604,37 @@ async def flush( # Snapshot states and reset immediately states = {} for key, acc in self.accumulators.items(): + # Skip state if nothing was accumulated + if acc.is_reset: + continue states[key] = acc.get_state() acc.reset() - # Reduce metrics from states for logging if any per-rank backend - if self.logger_backends: - # Use reduce_metrics_states for consistency - reduced_metrics = reduce_metrics_states([states]) + # Reduce and log to PER_RANK_REDUCE backends only (NO_REDUCE backends already logged in push) + if self.per_rank_reduce_backends: + metrics_for_backends = reduce_metrics_states([states]) - # Log to local logger_backends - for logger_backend in self.logger_backends: - await logger_backend.log(reduced_metrics, global_step) + for backend in self.per_rank_reduce_backends: + await backend.log_batch(metrics_for_backends, global_step) + + # Update step counter for streaming backends + # Note: This is incremented AFTER flush completes, so metrics recorded between + # flush(N) and flush(N+1) will stream with global_step=N+1. + self.global_step = global_step + 1 return states if return_state else {} async def shutdown(self): """Shutdown logger_backends if initialized.""" + if not self._is_initialized: logger.debug( f"Collector for {self.proc_name_with_rank} not initialized. 
Skipping shutdown" ) return - for logger_backend in self.logger_backends: - await logger_backend.finish() + for backend in self.per_rank_reduce_backends + self.per_rank_no_reduce_backends: + await backend.finish() ########### @@ -606,31 +652,45 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: dict[str, Any] | None = None, - name: str | None = None, + controller_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). - Can be used to behave differently for primary vs secondary roles. - primary_logger_metadata (dict[str, Any] | None): From global backend for + role (BackendRole): BackendRole.GLOBAL (controller) or BackendRole.LOCAL (per-rank). + Can be used to behave differently for controller vs rank roles. + controller_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. - name (str | None): Name for logging. + process_name (str | None): Process name for logging. Raises: ValueError if missing metadata for shared local init. """ pass @abstractmethod - async def log(self, metrics: list[Metric], global_step: int) -> None: - """ - Log a batch of metrics to the backend. + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: + """Log batch of accumulated metrics to backend Args: metrics: List of Metric objects to log. - global_step: Step number for x-axis alignment across metrics. + global_step: Step number for x-axis alignment across metrics.""" + pass + + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to backend immediately. + + NOTE: This method is called synchronously. + If your backend requires async I/O operations: + - Use asyncio.create_task() for fire-and-forget logging + - Consider internal buffering to avoid blocking the caller + + Example for async backend: + def log_stream(self, metric, global_step): + asyncio.create_task(self._async_log(metric, global_step)) """ pass @@ -639,7 +699,7 @@ async def finish(self) -> None: pass def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None: - """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + """Return sharable state after controller init (e.g., for shared modes). 
Called only on controller backends.""" return None @@ -652,39 +712,42 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: async def init( self, role: BackendRole, -        primary_logger_metadata: dict[str, Any] | None = None, -        name: str | None = None, +        controller_logger_metadata: dict[str, Any] | None = None, +        process_name: str | None = None, ) -> None: +        self.process_name = process_name -        self.name = name - -    async def log(self, metrics: list[Metric], global_step: int) -> None: +    async def log_batch( +        self, metrics: list[Metric], global_step: int, *args, **kwargs +    ) -> None: metrics_str = "\n".join( f"  {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) ) logger.info( -            f"=== [{self.name}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" +            f"=== [{self.process_name}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) +    def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: +        logger.info(f"{metric.key}: {metric.value}") + async def finish(self) -> None: pass class WandbBackend(LoggerBackend): """ -    Weights & Biases logging backend for distributed training. +    Weights & Biases logging backend. -    Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: -    Track a single process: reduce_across_ranks=True -    Track each process separately: reduce_across_ranks=False, share_run_id=False -    Track all processes to a single run: reduce_across_ranks=False, share_run_id=True +    For logging mode details, see `forge.observability.metrics.LoggingMode` documentation. + +    More details on WandB distributed logging: https://docs.wandb.ai/guides/track/log/distributed-training/ Configuration: -        reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). -            If False, enables per-rank logging; then use share_run_id to pick mode. -        share_run_id (bool, default False): Only used if reduce_across_ranks=False. -            True -> shared run across ranks; False -> separate runs per rank. +        logging_mode (LoggingMode): Determines the logging behavior. +        per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. +            If True, a single WandB run is created and all ranks log to it. It's particularly useful when +            logging with per_rank_no_reduce to capture a time-based stream of information. Not recommended when reducing values. project (str): WandB project name group (str, optional): WandB group name for organizing runs.
Defaults to "experiment_group" """ @@ -693,41 +756,37 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: super().__init__(logger_backend_config) self.project = logger_backend_config["project"] self.group = logger_backend_config.get("group", "experiment_group") - self.name = None + self.process_name = None self.run = None - self.reduce_across_ranks = logger_backend_config.get( - "reduce_across_ranks", True - ) - self.share_run_id = logger_backend_config.get("share_run_id", False) + self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) + self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) async def init( self, role: BackendRole, - primary_logger_metadata: dict[str, Any] | None = None, - name: str | None = None, + controller_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: - if primary_logger_metadata is None: - primary_logger_metadata = {} + if controller_logger_metadata is None: + controller_logger_metadata = {} - self.name = name + self.process_name = process_name - # Default global mode: only inits on controller - if self.reduce_across_ranks: + # GLOBAL_REDUCE mode: only inits on controller + if self.logging_mode == LoggingMode.GLOBAL_REDUCE: if role != BackendRole.GLOBAL: - logger.debug( - f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." - ) + logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.") return await self._init_global() - # Per-rank modes based on share_run_id bool - elif role == BackendRole.GLOBAL and self.share_run_id: + # Per-rank modes based on per_rank_share_run bool + elif role == BackendRole.GLOBAL and self.per_rank_share_run: await self._init_shared_global() elif role == BackendRole.LOCAL: - if self.share_run_id: - await self._init_shared_local(primary_logger_metadata) + if self.per_rank_share_run: + await self._init_shared_local(controller_logger_metadata) else: await self._init_per_rank() @@ -739,7 +798,9 @@ async def _init_global(self): async def _init_per_rank(self): import wandb - self.run = wandb.init(project=self.project, group=self.group, name=self.name) + self.run = wandb.init( + project=self.project, group=self.group, name=self.process_name + ) async def _init_shared_global(self): import wandb @@ -749,13 +810,13 @@ async def _init_shared_global(self): ) self.run = wandb.init(project=self.project, group=self.group, settings=settings) - async def _init_shared_local(self, primary_metadata: dict[str, Any]): + async def _init_shared_local(self, controller_metadata: dict[str, Any]): import wandb - shared_id = primary_metadata.get("shared_run_id") + shared_id = controller_metadata.get("shared_run_id") if shared_id is None: raise ValueError( - f"Shared ID required but not provided for {self.name} backend init" + f"Shared ID required but not provided for {self.process_name} backend init" ) # Clear any stale service tokens that might be pointing to dead processes @@ -766,7 +827,9 @@ async def _init_shared_local(self, primary_metadata: dict[str, Any]): service_token.clear_service_in_env() - settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) + settings = wandb.Settings( + mode="shared", x_primary=False, x_label=self.process_name + ) self.run = wandb.init( id=shared_id, project=self.project, @@ -774,29 +837,49 @@ async def _init_shared_local(self, primary_metadata: dict[str, Any]): settings=settings, ) - async def log(self, metrics: list[Metric], global_step: int) -> None: - if self.run: - # 
Convert metrics to WandB log format -            log_data = {"global_step": global_step} -            for metric in metrics: -                log_data[metric.key] = metric.value - -            self.run.log(log_data) -            logger.info( -                f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" +    async def log_batch( +        self, metrics: list[Metric], global_step: int, *args, **kwargs +    ) -> None: +        if not self.run: +            logger.debug( +                f"WandbBackend: No run started, skipping log for {self.process_name}" ) -        else: -            logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") +            return + +        # Convert metrics to WandB log format +        log_data = {} +        for metric in metrics: +            log_data[metric.key] = metric.value + +        self.run.log(log_data, step=global_step) +        logger.info( +            f"WandbBackend: Logged {len(metrics)} metrics at step {global_step}" +        ) + +    def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: +        """Stream single metric to WandB with both step and timestamp.""" +        if not self.run: +            return + +        # Log with custom timestamp for precision +        # Users can choose timestamp as the x-axis in the WandB UI and display it as a datetime +        log_data = { +            metric.key: metric.value, +            "timestamp": metric.timestamp, +        } + +        # Note: we don't pass step here since WandB keeps only the latest value for each step +        self.run.log(log_data) def get_metadata_for_secondary_ranks(self) -> dict[str, Any]: -        if self.run and not self.reduce_across_ranks and self.share_run_id: +        if self.run and self.per_rank_share_run: return {"shared_run_id": self.run.id} return {} async def finish(self) -> None: if self.run: self.run.finish() -            logger.info(f"WandbBackend {self.name}: Finished run") +            logger.info(f"WandbBackend {self.process_name}: Finished run") def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index eae50c2db..bcbdb6755 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -82,15 +82,13 @@ async def main(): group = f"grpo_exp_{int(time.time())}" # Config format: {backend_name: backend_config_dict} -    # Each backend can specify reduce_across_ranks to control distributed logging behavior config = { -        "console": {"reduce_across_ranks": True}, +        "console": {"logging_mode": "global_reduce"}, "wandb": { -            "project": "my_project", +            "project": "toy_metrics", "group": group, -            "reduce_across_ranks": False, -            # Only useful if NOT reduce_across_ranks. -            "share_run_id": False,  # Share run ID across ranks -- Not recommended.
+ "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce + "per_rank_share_run": True, }, } diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 425352340..f41dec56a 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -32,7 +32,9 @@ async def run(cfg: DictConfig): await init_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) + metric_logging_cfg = cfg.get( + "metric_logging", {"console": {"logging_mode": "global_reduce"}} + ) mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py index 1c315b2e9..fd3c96687 100644 --- a/tests/unit_tests/observability/test_metric_actors.py +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -6,6 +6,8 @@ """Optimized unit tests for metric actors functionality.""" +from unittest.mock import patch + import pytest from forge.observability.metric_actors import ( @@ -13,6 +15,8 @@ GlobalLoggingActor, LocalFetcherActor, ) + +from forge.observability.metrics import LoggingMode from monarch.actor import this_host @@ -62,7 +66,7 @@ async def test_global_logger_basic_ops(self, global_logger): async def test_backend_init(self, local_fetcher): """Test backend initialization and shutdown.""" metadata = {"wandb": {"shared_run_id": "test123"}} - config = {"console": {"reduce_across_ranks": False}} + config = {"console": {"logging_mode": LoggingMode.PER_RANK_REDUCE}} await local_fetcher.init_backends.call_one(metadata, config, global_step=5) await local_fetcher.shutdown.call_one() @@ -71,7 +75,7 @@ async def test_backend_init(self, local_fetcher): class TestRegistrationLifecycle: """Test registration lifecycle.""" - @pytest.mark.timeout(3) + @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_registration_lifecycle(self, global_logger, local_fetcher): """Test complete registration/deregistration lifecycle.""" @@ -108,25 +112,38 @@ async def test_valid_backend_configs(self, global_logger): # Empty config await global_logger.init_backends.call_one({}) - # Valid configs for different reduce_across_ranks modes - for reduce_across_ranks in [True, False]: - config = {"console": {"reduce_across_ranks": reduce_across_ranks}} + # Valid configs for different logging_mode modes + for logging_mode in [LoggingMode.GLOBAL_REDUCE, LoggingMode.PER_RANK_NO_REDUCE]: + config = {"console": {"logging_mode": logging_mode}} await global_logger.init_backends.call_one(config) - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_invalid_backend_configs(self, global_logger): - """Test invalid backend configurations are handled gracefully.""" - # Empty config should work - await global_logger.init_backends.call_one({}) - - # Config with only project should work - config_with_project = {"console": {"project": "test_project"}} - await global_logger.init_backends.call_one(config_with_project) - - # Config with reduce_across_ranks should work - config_with_reduce = {"console": {"reduce_across_ranks": True}} - await global_logger.init_backends.call_one(config_with_reduce) + def test_invalid_backend_configs(self): + """Test invalid backend configurations and warnings using direct validation.""" + actor = GlobalLoggingActor() + + # Test 1: Invalid logging_mode should raise ValueError + with 
pytest.raises(ValueError, match="is not a valid LoggingMode"): + actor._validate_backend_config("console", {"logging_mode": "invalid_mode"}) + + # Test 2: WandB PER_RANK_REDUCE + per_rank_share_run=True should warn + with patch("forge.observability.metric_actors.logger.warning") as mock_warn: + config = { + "logging_mode": "per_rank_reduce", + "per_rank_share_run": True, + "project": "test_project", + } + + result = actor._validate_backend_config("wandb", config) + + # Should have logged warning about suboptimal config + mock_warn.assert_called_once() + warning_msg = str(mock_warn.call_args) + assert "not recommended" in warning_msg + + # Should still return valid config with LoggingMode enum + assert result["logging_mode"] == LoggingMode.PER_RANK_REDUCE + assert result["per_rank_share_run"] is True + assert result["project"] == "test_project" class TestErrorHandling: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index b4a8ffcdf..cda3679a5 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -16,6 +16,7 @@ BackendRole, ConsoleBackend, get_logger_backend_class, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -88,7 +89,9 @@ async def test_backend_role_usage(self): await console_backend.init(role=BackendRole.LOCAL) # Test WandbBackend role validation without WandB initialization - wandb_backend = WandbBackend({"project": "test"}) + wandb_backend = WandbBackend( + {"project": "test", "logging_mode": "global_reduce"} + ) # Mock all the WandB init methods to focus only on role validation with patch.object(wandb_backend, "_init_global"), patch.object( @@ -298,14 +301,14 @@ def test_wandb_backend_creation(self): config = { "project": "test_project", "group": "test_group", - "reduce_across_ranks": True, + "logging_mode": "global_reduce", } backend = WandbBackend(config) assert backend.project == "test_project" assert backend.group == "test_group" - assert backend.reduce_across_ranks is True - assert backend.share_run_id is False # default + assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE + assert backend.per_rank_share_run is False # default # Test metadata method metadata = backend.get_metadata_for_secondary_ranks() @@ -318,10 +321,10 @@ async def test_console_backend(self): await backend.init(role=BackendRole.LOCAL) - # Test log - should not raise + # Test log_batch - should not raise # Create a test metric test_metric = Metric("test", 1.0, Reduce.MEAN) - await backend.log([test_metric], global_step=1) + await backend.log_batch([test_metric], global_step=1) await backend.finish() # Should not raise
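For illustration, a minimal custom backend written against the LoggerBackend interface introduced above, showing the new log_batch / log_stream split. This is a sketch, not part of the diff: the PrintBackend name, the print-based output, and the _async_log helper are illustrative assumptions.

import asyncio
from typing import Any

from forge.observability.metrics import BackendRole, LoggerBackend, Metric


class PrintBackend(LoggerBackend):
    """Toy backend: reduced batches arrive via log_batch, raw values via log_stream."""

    async def init(
        self,
        role: BackendRole,
        controller_logger_metadata: dict[str, Any] | None = None,
        process_name: str | None = None,
    ) -> None:
        # Same behavior for controller and per-rank roles in this sketch
        self.process_name = process_name

    async def log_batch(
        self, metrics: list[Metric], global_step: int, *args, **kwargs
    ) -> None:
        # Called on flush with already-reduced values
        for m in metrics:
            print(f"[{self.process_name}] step={global_step} {m.key}={m.value}")

    def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
        # Called synchronously on every record_metric() in per_rank_no_reduce mode;
        # offload async I/O instead of blocking the caller, as the base class docstring suggests
        asyncio.create_task(self._async_log(metric, global_step))

    async def _async_log(self, metric: Metric, global_step: int) -> None:
        print(f"[{self.process_name}] stream step={global_step} {metric.key}={metric.value}")

    async def finish(self) -> None:
        pass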
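For the global_reduce path, per-rank accumulator states are merged with reduce_metrics_states before the controller logs them. A rough worked example, assuming the state layout shown in MetricCollector.flush()'s docstring; the exact return shape of reduce_metrics_states is one reduced entry per metric key:

from forge.observability.metrics import reduce_metrics_states

# Two ranks report mean-accumulator states for the same key
rank0_states = {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}
rank1_states = {"loss": {"reduction_type": "mean", "sum": 0.8, "count": 1}}

# "loss" reduces to (1.2 + 0.8) / (3 + 1) = 0.5; GlobalLoggingActor then calls log_batch on each global backend
reduced = reduce_metrics_states([rank0_states, rank1_states])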