From e162788f6c80df0ee50ae892d54b13999aeec604 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 19 Sep 2025 10:03:16 -0700 Subject: [PATCH 01/25] metric logger simple example --- src/forge/controller/metric_actors.py | 211 ++++++++++++++++++++++++++ src/forge/controller/metric_main.py | 146 ++++++++++++++++++ src/forge/controller/provisioner.py | 46 +++++- 3 files changed, 400 insertions(+), 3 deletions(-) create mode 100644 src/forge/controller/metric_actors.py create mode 100644 src/forge/controller/metric_main.py diff --git a/src/forge/controller/metric_actors.py b/src/forge/controller/metric_actors.py new file mode 100644 index 000000000..3d94eefd8 --- /dev/null +++ b/src/forge/controller/metric_actors.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +from collections import defaultdict +from typing import Any, Dict, List + +from monarch.actor import Actor, endpoint + + +# ============================================================================ +# LocalLoggingActor +# ============================================================================ +class LocalLoggingActor(Actor): + """Local logging actor that accumulates metrics within a process.""" + + def __init__(self): + self._metrics = defaultdict(list) + + @endpoint + async def get_metrics(self) -> Dict[str, List[Any]]: + """Get all accumulated metrics (called by GlobalLoggingActor).""" + # Return copy and reset for next collection + result = dict(self._metrics) + self._metrics.clear() + print(f"LocalLoggingActor: Returning {len(result)} metric keys") + return result + + @endpoint + def push_metrics( + self, key: str, value: Any + ) -> None: # Note: not async for broadcast + """Store a metric value (called by service actors).""" + self._metrics[key].append(value) + print(f"LocalLoggingActor: Stored {key}={value}") + + +# ============================================================================ +# Backend System +# ============================================================================ +class Backend: + """Base class for logging backends.""" + + def push(self, metrics: Dict[str, Any], step: int) -> None: + pass + + +class ConsoleBackend(Backend): + """Simple console backend for testing.""" + + def push(self, metrics: Dict[str, Any], step: int) -> None: + print(f"\n=== METRICS STEP {step} ===") + for key, value in metrics.items(): + print(f" {key}: {value}") + print("========================\n") + + +# ============================================================================ +# GlobalLoggingActor +# ============================================================================ + + +class GlobalLoggingActor(Actor): + """Global logger that coordinates across all processes.""" + + def __init__(self): + self._loggers: Dict[str, LocalLoggingActor] = {} + self._backends: List[Backend] = [ConsoleBackend()] # Default console backend + + @endpoint + async def register(self, local_actor: LocalLoggingActor, name: str) -> None: + """Register a LocalLoggingActor from a process.""" + self._loggers[name] = local_actor + print(f"GlobalLoggingActor: Registered {name}") + + @endpoint + async def deregister(self, name: str) -> None: + """Deregister a LocalLoggingActor.""" + if name in self._loggers: + del self._loggers[name] + print(f"GlobalLoggingActor: Deregistered {name}") + + @endpoint + async def flush(self, step: int) -> None: + 
"""Collect metrics from all processes and send to backends.""" + if not self._loggers: + print("GlobalLoggingActor: No loggers registered") + return + + print(f"GlobalLoggingActor: Flushing metrics for step {step}") + + # Collect from all local loggers + metrics_list = await asyncio.gather( + *[actor.get_metrics.call_one() for actor in self._loggers.values()] + ) + + # Simple aggregation - just combine all metrics + print("metrics_list", metrics_list) + all_metrics = {} + for metrics in metrics_list: + for key, values in metrics.items(): + if key not in all_metrics: + all_metrics[key] = [] + all_metrics[key].extend(values) + + # Send to all backends + for backend in self._backends: + backend.push(all_metrics, step) + + +# =========================================================================== + + +def debug_context(ctx, label: str = "DEBUG") -> None: + """ + Utility function to fully debug a context object and its nested attributes. + + Args: + ctx: The context object to debug + label: Label for this debug session + """ + print(f"\n=== {label} ===") + print(f"Context type: {type(ctx)}") + print(f"Context dir: {[attr for attr in dir(ctx) if not attr.startswith('__')]}") + + # Print all attributes + for attr in dir(ctx): + if not attr.startswith("__"): + try: + val = getattr(ctx, attr) + print(f" ctx.{attr}: {val} (type: {type(val)})") + + # f this is actor_instance, explore it + if attr == "actor_instance": + print(f" ๐Ÿ” DEEP DIVE INTO ACTOR_INSTANCE:") + print(f" Type: {type(val)}") + + # Get all attributes + all_attrs = [a for a in dir(val) if not a.startswith("__")] + print(f" All attributes: {all_attrs}") + + # Check for anything related to proc, mesh, logger, etc. + interesting_attrs = [ + a + for a in all_attrs + if any( + keyword in a.lower() + for keyword in [ + "proc", + "mesh", + "log", + "local", + "spawn", + "process", + "actor", + ] + ) + ] + print(f" Interesting attributes: {interesting_attrs}") + + # Print ALL attributes with their values + for sub_attr in all_attrs: + try: + sub_val = getattr(val, sub_attr) + print( + f" โ””โ”€ {sub_attr}: {sub_val} (type: {type(sub_val)})" + ) + + # If it's an object, go one level deeper + if hasattr(sub_val, "__dict__") or hasattr( + sub_val, "__slots__" + ): + try: + deep_attrs = [ + a + for a in dir(sub_val) + if not a.startswith("__") + ][ + :3 + ] # Just first 3 + if deep_attrs: + print( + f" โ””โ”€ {sub_attr} has: {deep_attrs}" + ) + for deep_attr in deep_attrs: + try: + deep_val = getattr(sub_val, deep_attr) + print( + f" โ””โ”€ {deep_attr}: {deep_val}" + ) + except: + print( + f" โ””โ”€ {deep_attr}: " + ) + except: + pass + except Exception as e: + print(f" โ””โ”€ {sub_attr}: ") + + # For other attributes, shorter exploration + elif hasattr(val, "__dict__") or hasattr(val, "__slots__"): + sub_attrs = [a for a in dir(val) if not a.startswith("__")] + if sub_attrs: + print(f" โ””โ”€ {attr} attributes: {sub_attrs}") + except: + print(f" ctx.{attr}: ") + + print(f"=== END {label} ===\n") diff --git a/src/forge/controller/metric_main.py b/src/forge/controller/metric_main.py new file mode 100644 index 000000000..78a24cfe7 --- /dev/null +++ b/src/forge/controller/metric_main.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import asyncio +from dataclasses import dataclass +from typing import Any + +from monarch.actor import context, endpoint, get_or_spawn_controller + +from forge.controller.actor import ForgeActor + +from forge.controller.metric_actors import debug_context, GlobalLoggingActor + + +def push_metrics(key: str, value: Any) -> None: + """ + Push metrics to LocalLoggingActor + + Args: + key: Metric name + value: Metric value + """ + try: + # Just use the regular monarch context + ctx = context() + + # Try to get LocalLoggingActor from context + local_logging_actor = None + + if hasattr(ctx, "actor_instance") and hasattr(ctx.actor_instance, "_proc_mesh"): + proc_mesh = ctx.actor_instance._proc_mesh + if proc_mesh is not None and hasattr(proc_mesh, "_local_logger"): + local_logging_actor = proc_mesh._local_logger + print(f"โœ… Found LocalLoggingActor via context for {key}") + + if local_logging_actor: + local_logging_actor.push_metrics.broadcast(key, value) + else: + print( + f"โŒ No LocalLoggingActor found in context, dropping metric {key}={value}" + ) + debug_context(ctx, f"CONTEXT DEBUG for {key}={value}") + + except Exception as e: + print(f"โŒ push_metrics failed for {key}={value}: {e}") + import traceback + + traceback.print_exc() + + +async def flush(step: int) -> None: + """Flush all metrics globally.""" + try: + g = await get_or_spawn_controller("global_logger", GlobalLoggingActor) + await g.flush.call_one(step) + except Exception as e: + print(f"โŒ flush failed: {e}") + import traceback + + traceback.print_exc() + + +@dataclass +class Trainer(ForgeActor): + """Trainer that uses global_logger.push_metrics.""" + + def __post_init__(self): + self.step_counter = 0 + + @endpoint + async def train_step(self) -> int: + """Simulate one training step.""" + self.step_counter += 1 + push_metrics("step_counter", self.step_counter) + print(f"Trainer: Completed step {self.step_counter}") + return self.step_counter + + @endpoint + async def debug_context(self) -> None: + """Debug what the service actor can see.""" + ctx = context() + debug_context(ctx) + + +# ============================================================================ +# Main Training Loop +# ============================================================================ + + +async def continuous_training(trainer: Trainer, num_steps: int = 5): + """Run training loop with periodic flushing.""" + print(f"\ Starting training for {num_steps} steps...") + + for step in range(num_steps): + print(f"\n--- Step {step + 1} ---") + + # Run training step + await trainer.train_step.choose() + + if (step + 1) % 2 == 0: # Flush every 2 steps + print(f"๐Ÿ”„ Flushing metrics at step {step + 1}") + await flush(step + 1) + + await asyncio.sleep(0.1) + + print("โœ… Training completed!") + + +async def main(): + """Main function demonstrating the REAL issue (following your architecture).""" + + print("1. Spawning trainer service...") + service_config = {"procs_per_replica": 1, "num_replicas": 1, "with_gpus": False} + + # This should internally: + # - Call get_proc_mesh() + # - Spawn LocalLoggingActor in that process (via provisioner.py changes) + # - Register LocalLoggingActor with GlobalLoggingActor + # - Make LocalLoggingActor accessible to the Trainer via context + trainer = await Trainer.options(**service_config).as_service() + + # Debug what the service actor can see + print("\n2. Debugging service actor context...") + await trainer.debug_context.choose() + + # Test the full training loop with metrics + flushing + print("\n3. 
Running training loop with metrics + flushing...") + await continuous_training(trainer, num_steps=2) + + # Shutdown + print("\n4. Shutting down...") + await trainer.shutdown() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception as e: + print(f"\n Failed with error: {e}") + import traceback + + traceback.print_exc() + raise diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 27aa1293e..2860151ba 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -12,15 +12,23 @@ import uuid import monarch + +from forge.controller.metric_actors import GlobalLoggingActor, LocalLoggingActor +from forge.types import ProcessConfig from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer from monarch._src.actor.shape import NDSlice, Shape -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host +from monarch.actor import ( + Actor, + endpoint, + get_or_spawn_controller, + HostMesh, + ProcMesh, + this_host, +) from monarch.tools import commands from monarch.tools.components import hyperactor from monarch.tools.config import Config -from forge.types import ProcessConfig - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -184,6 +192,26 @@ def bootstrap(gpu_ids: int): procs._host = host_mesh + # Spawn local logging actor on each process and register with global logger + try: + local_logging_actor = await procs.spawn( + "local_logging_actor", LocalLoggingActor + ) + procs._local_logger = local_logging_actor + + # Register with global logger + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + process_name = f"proc_mesh_{id(procs)}" + await global_logger.register.call_one(local_logging_actor, process_name) + + logger.debug( + f"Spawned and registered LocalLoggingActor for {process_name}" + ) + except Exception as e: + logger.warning(f"Failed to spawn LocalLoggingActor: {e}") + # If we created a server, track so we can tear it down later. 
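            # [Review note] The registration above is mirrored by the
            # deregistration in stop_proc_mesh() below. A condensed view of the
            # intended lifecycle, using only calls introduced in this patch:
            #
            #   local = await procs.spawn("local_logging_actor", LocalLoggingActor)
            #   glob = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
            #   await glob.register.call_one(local, f"proc_mesh_{id(procs)}")
            #   ...                                 # training runs, metrics accumulate
            #   await glob.deregister.call_one(f"proc_mesh_{id(procs)}")  # on stop_proc_mesh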
if server_name: self._server_names.append(server_name) @@ -194,6 +222,18 @@ def bootstrap(gpu_ids: int): async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" async with self._lock: + # Deregister local logger from global logger + if hasattr(proc_mesh, "_local_logger"): + try: + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + process_name = f"proc_mesh_{id(proc_mesh)}" + await global_logger.deregister.call_one(process_name) + logger.debug(f"Deregistered LocalLoggingActor for {process_name}") + except Exception as e: + logger.warning(f"Failed to deregister LocalLoggingActor: {e}") + if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] gpu_manager.release_gpus(proc_mesh._gpu_ids) From 9f13bfbb544451ba330fd66e35f2837e0b3ca8aa Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 22 Sep 2025 14:26:07 -0700 Subject: [PATCH 02/25] it works --- src/forge/controller/__init__.py | 5 +- src/forge/controller/provisioner.py | 31 +- src/forge/controller/service/replica.py | 2 +- src/forge/controller/v3/__init__.py | 5 + src/forge/controller/v3/metric_actors.py | 165 +++++++ src/forge/controller/v3/metric_main.py | 104 +++++ src/forge/controller/v3/metrics.py | 571 +++++++++++++++++++++++ 7 files changed, 868 insertions(+), 15 deletions(-) create mode 100644 src/forge/controller/v3/__init__.py create mode 100644 src/forge/controller/v3/metric_actors.py create mode 100644 src/forge/controller/v3/metric_main.py create mode 100644 src/forge/controller/v3/metrics.py diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index 71d35c433..23a3d6804 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -3,8 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
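[Review note] Patch 02 replaces the single LocalLoggingActor with a layered
design. The data path introduced across metric_actors.py and metrics.py is:

    push_metrics(key, value, reduction)
        -> MetricCollector         # per-rank singleton, owns the accumulators
        -> LocalFetcherActor       # one per proc mesh, log_and_reset()
        -> GlobalLoggingActor      # controller: gathers states, reduces
        -> Backend(s)              # ConsoleBackend / WandbBackend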
- -from .actor import ForgeActor from .proc_mesh import get_proc_mesh, stop_proc_mesh @@ -12,7 +10,7 @@ # service async def spawn_actors( name: str, - actor_cls: ForgeActor, + actor_cls, cfg, processes, set_address: bool = False, @@ -28,5 +26,4 @@ async def spawn_actors( "spawn_actors", "stop_proc_mesh", "get_proc_mesh", - "ForgeActor", ] diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 2860151ba..c22d29bb5 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -12,9 +12,6 @@ import uuid import monarch - -from forge.controller.metric_actors import GlobalLoggingActor, LocalLoggingActor -from forge.types import ProcessConfig from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import ( @@ -29,6 +26,8 @@ from monarch.tools.components import hyperactor from monarch.tools.config import Config +from forge.types import ProcessConfig + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -194,23 +193,33 @@ def bootstrap(gpu_ids: int): # Spawn local logging actor on each process and register with global logger try: - local_logging_actor = await procs.spawn( - "local_logging_actor", LocalLoggingActor + from forge.controller.v3.metric_actors import ( + GlobalLoggingActor, + LocalFetcherActor, ) - procs._local_logger = local_logging_actor + + local_fetcher_actor = await procs.spawn( + "local_fetcher_actor", LocalFetcherActor + ) + procs._local_fetcher = local_fetcher_actor # Register with global logger global_logger = await get_or_spawn_controller( "global_logger", GlobalLoggingActor ) process_name = f"proc_mesh_{id(procs)}" - await global_logger.register.call_one(local_logging_actor, process_name) + await global_logger.register_fetcher.call_one( + local_fetcher_actor, process_name + ) logger.debug( - f"Spawned and registered LocalLoggingActor for {process_name}" + f"Spawned and registered LocalFetcherActor for {process_name}" ) except Exception as e: - logger.warning(f"Failed to spawn LocalLoggingActor: {e}") + logger.warning(f"Failed to spawn LocalFetcherActor: {e}") + import traceback + + traceback.print_stack() # If we created a server, track so we can tear it down later. if server_name: @@ -223,8 +232,10 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" async with self._lock: # Deregister local logger from global logger - if hasattr(proc_mesh, "_local_logger"): + if hasattr(proc_mesh, "_local_fetcher"): try: + from forge.controller.v3.metric_actors import GlobalLoggingActor + global_logger = await get_or_spawn_controller( "global_logger", GlobalLoggingActor ) diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index b84e5eec7..569f0e829 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -15,7 +15,7 @@ from monarch.actor import ActorError -from forge.controller import ForgeActor +from forge.controller.actor import ForgeActor from forge.types import ProcessConfig logger = logging.getLogger(__name__) diff --git a/src/forge/controller/v3/__init__.py b/src/forge/controller/v3/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/src/forge/controller/v3/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/src/forge/controller/v3/metric_actors.py b/src/forge/controller/v3/metric_actors.py new file mode 100644 index 000000000..00b33ae9c --- /dev/null +++ b/src/forge/controller/v3/metric_actors.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +from typing import Any, Dict, Optional + +from monarch.actor import Actor, current_rank, endpoint + + +class LocalFetcherActor(Actor): + @endpoint + async def log_and_reset( + self, step: int, return_state: bool = False + ) -> Optional[Dict[str, Dict[str, Any]]]: + """Log to local backends (if any), optionally return states, and reset.""" + from forge.controller.v3.metrics import MetricCollector + + collector = MetricCollector() + result = await collector.log_and_reset(step, return_state=return_state) + print( + f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Flushed step {step}, returned state: {return_state}" + ) + return result + + @endpoint + async def init_collector(self, primary_backend_states: Dict[str, Dict[str, Any]]): + from forge.controller.v3.metrics import MetricCollector + + collector = MetricCollector() + await collector._init(primary_backend_states) + print(f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Initialized collector") + + @endpoint + async def shutdown(self): + from forge.controller.v3.metrics import MetricCollector + + collector = MetricCollector() + await collector.shutdown() + print(f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Finished all backends") + + +# GlobalLoggingActor (coordinator) +class GlobalLoggingActor(Actor): + def __init__(self): + self.fetchers: Dict[str, LocalFetcherActor] = {} + self.config: Optional[Dict[str, Any]] = None + self.global_backends: Dict[str, "Backend"] = {} + self.primary_backend_states: Dict[str, Dict[str, Any]] = {} + + @endpoint + async def init_config(self, config: Dict[str, Any]): + """Main calls this to set config and init global backends if needed.""" + self.config = config + + # Validate unique classes + classes = [b["class"] for b in config["backends"]] + if len(set(classes)) != len(classes): + raise ValueError("Duplicate backend classes in config") + + # Init global backends and states where needed + from forge.controller.v3.metrics import create_backend + + for backend_config in config["backends"]: + cls_name = backend_config["class"] + backend = create_backend( + backend_config + ) # Factory: returns instance based on type + + await backend.setup(self.config, role="global") + primary_state = backend.get_primary_state() or {} + log_per_rank = backend_config.get("log_per_rank", True) + if log_per_rank: + self.primary_backend_states[cls_name] = primary_state + if not log_per_rank: + self.global_backends[cls_name] = backend + print( + f"๐ŸŒ Global: Processed backend {cls_name} (log_per_rank: {log_per_rank})" + ) + + # Eager init collectors on all registered fetchers in parallel, passing primary states + if self.fetchers: + tasks = [ + fetcher.init_collector.call(self.primary_backend_states) + for fetcher in self.fetchers.values() + ] + await asyncio.gather(*tasks, return_exceptions=True) + print(f"๐ŸŒ Global: Initialized {len(tasks)} collectors in parallel") + + print("๐ŸŒ Global: Config set") + + @endpoint + def get_metric_logger_cfg(self) -> 
Optional[Dict[str, Any]]: + return self.config + + @endpoint + async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): + self.fetchers[name] = fetcher + print(f"๐ŸŒ Global: Registered {name} (total: {len(self.fetchers)})") + + @endpoint + async def flush_global(self, step: int): + if not self.fetchers: + print("๐ŸŒ Global: No fetchers") + return + + print(f"๐ŸŒ Global: Flushing step {step} across {len(self.fetchers)}") + + config = self.config + has_reduce = any(not b.get("log_per_rank", True) for b in config["backends"]) + return_state = has_reduce # Flag for reduce + + # Broadcast log_and_reset to all fetchers + results = await asyncio.gather( + *[ + f.log_and_reset.call(step, return_state=return_state) + for f in self.fetchers.values() + ], + return_exceptions=True, + ) + + if has_reduce: + # Flatten: Handle both single-process (dict/None) and multi-process (list of dicts/None) + all_local_results = [] + for res in results: + res = ( + res._values + ) # TODO: avoid using internal state. Could use items() instead, but has to parse metadata. + if isinstance(res, list): + all_local_results.extend(res) + elif res is not None: + all_local_results.append(res) + + # Filter states from results (None if not returned) + all_local_states = [r for r in all_local_results if isinstance(r, dict)] + if not all_local_states: + print("๐ŸŒ Global: No local states gathered") + return + + # Reduce + from forge.controller.v3.metrics import reduce_across_ranks + + reduced_metrics = reduce_across_ranks(all_local_states) + + # Log to each global backend + for backend_name, backend in self.global_backends.items(): + await backend.log(reduced_metrics, step) + print(f"๐ŸŒ Global: Logged reduced metrics {reduced_metrics} at step {step}") + + @endpoint + async def shutdown(self): + # Finish per-rank backends via fetchers + if self.fetchers: + tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] + await asyncio.gather(*tasks, return_exceptions=True) + print(f"๐ŸŒ Global: Finished {len(self.fetchers)} fetchers' backends") + + # Finish global backends + for backend_name, backend in self.global_backends.items(): + await backend.finish() + print(f"๐ŸŒ Global: Finished global backend {backend_name}") + + print("๐ŸŒ Global: Shutdown complete") diff --git a/src/forge/controller/v3/metric_main.py b/src/forge/controller/v3/metric_main.py new file mode 100644 index 000000000..5f81ba1b7 --- /dev/null +++ b/src/forge/controller/v3/metric_main.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
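[Review note] The states gathered by flush_global() above are the accumulator
get_state() payloads defined in metrics.py. A worked example of the reduce
path for two ranks, with illustrative values:

    rank0 = {"train/loss": {"reduction_type": "mean", "sum": 4.0, "count": 2}}
    rank1 = {"train/loss": {"reduction_type": "mean", "sum": 2.0, "count": 1}}
    # reduce_across_ranks([rank0, rank1]) merges per key via
    # MeanAccumulator.merge_states: (4.0 + 2.0) / (2 + 1) == 2.0
    # result: {"train/loss": 2.0}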
+ +import asyncio +import sys +import time + +from monarch.actor import current_rank, endpoint, get_or_spawn_controller + +from forge.controller.actor import ForgeActor +from forge.controller.v3.metric_actors import GlobalLoggingActor +from forge.controller.v3.metrics import push_metrics, ReductionType + + +class TrainActor(ForgeActor): + @endpoint + async def train_step(self, step: int): + rank = current_rank().rank + value = rank * 1000 + 100 * step + print(f"๐Ÿ”ง Train rank {rank}: Step {step}, loss={value}") + await push_metrics("train/loss", value) + + +class GeneratorActor(ForgeActor): + @endpoint + async def generate_step(self, step: int, substep: int): + rank = current_rank().rank + value = rank * 1000 + step * 100 + substep * 10 + print(f"๐ŸŽฏ Gen rank {rank}: Step {step}.{substep}, tokens={value}") + await push_metrics("generate/tokens", value, ReductionType.SUM) + + +# Main +async def main(mode: str = "wandb_all_log_all"): + group = f"experiment_group_{int(time.time())}" + if mode == "wandb_all_log_all": + backends = [ + {"class": "console", "log_per_rank": True}, + { + "class": "wandb", + "project": "my_project", + "group": group, + "mode": "wandb_all_log_all", + "log_per_rank": True, + }, + ] + elif mode == "wandb_rank_0_reduce_all": + backends = [ + { + "class": "wandb", + "project": "my_project", + "group": group, + "mode": "wandb_rank_0_reduce_all", + "log_per_rank": False, + }, + ] + else: # wandb_rank_0_log_all + backends = [ + { + "class": "wandb", + "project": "my_project", + "group": group, + "mode": "wandb_rank_0_log_all", + "log_per_rank": True, + }, + ] + + logging_config = { + "backends": backends, + } + service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} + + # Spawn services first (triggers registrations via provisioner hook) + trainer = await TrainActor.options(**service_config).as_service() + generator = await GeneratorActor.options(**service_config).as_service() + + # Now init config on global (inits backends eagerly across fetchers) + global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor) + await global_logger.init_config.call_one(logging_config) + + for i in range(3): + print(f"\n=== Global Step {i} ===") + await trainer.train_step.call(i) + for sub in range(3): + await generator.generate_step.call(i, sub) + await global_logger.flush_global.call_one(i) + + await global_logger.shutdown.call_one() + + +if __name__ == "__main__": + mode = sys.argv[1] if len(sys.argv) > 1 else "wandb_all_log_all" + valid_modes = [ + "wandb_all_log_all", + "wandb_rank_0_log_all", + "wandb_rank_0_reduce_all", + ] + if mode not in valid_modes: + print(f"Invalid mode: {mode}. Use {valid_modes}") + sys.exit(1) + asyncio.run(main(mode)) diff --git a/src/forge/controller/v3/metrics.py b/src/forge/controller/v3/metrics.py new file mode 100644 index 000000000..00059ef73 --- /dev/null +++ b/src/forge/controller/v3/metrics.py @@ -0,0 +1,571 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
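[Review note] The demo above selects its logging mode from argv; assuming it
is run as a script from the repo root:

    python src/forge/controller/v3/metric_main.py                          # wandb_all_log_all (default)
    python src/forge/controller/v3/metric_main.py wandb_rank_0_log_all     # one shared W&B run
    python src/forge/controller/v3/metric_main.py wandb_rank_0_reduce_all  # reduce, log from controller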
+ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional + +import wandb +from monarch.actor import context, current_rank + + +# Reduction Types +class ReductionType(Enum): + MEAN = "mean" + SUM = "sum" + MAX = "max" + MIN = "min" + COUNT = "count" + + @property + def accumulator_class(self): + + mapping = { + ReductionType.MEAN: MeanAccumulator, + ReductionType.SUM: SumAccumulator, + ReductionType.MAX: MaxAccumulator, + ReductionType.MIN: MinAccumulator, + ReductionType.COUNT: CountAccumulator, + } + return mapping[self] + + +def get_actor_name_for_logging(): + """ + Extract actor information from Monarch context and return formatted name for logging. + Returns string like "{actor_type}_{replica_id[-6:]}_r{local_rank_int}" + + #TODO: this is flaky as it currently relies on string parsing. + """ + + # Add more defensive checks + ctx = context() + if ctx is None: + print("โš ๏ธ Warning: context() returned None") + return "UnknownActor_r0_l0" + + actor_instance = ctx.actor_instance + if actor_instance is None: + print("โš ๏ธ Warning: actor_instance is None") + return "UnknownActor_r0_l0" + + rank = current_rank() + if rank is None: + print("โš ๏ธ Warning: current_rank() returned None") + return "UnknownActor_r0_l0" + + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + rank_name = "UnknownActor_r0_l0" # fallback + if len(parts) >= 2: + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Extract world ID and proc rank + world_id = world_part.split("[")[0] if "[" in world_part else world_part + + # Extract clean actor name (remove "Configured" suffix if present) + if "[" in actor_part: + actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" + if actor_name.endswith("Configured"): + actor_name = actor_name[:-10] # Remove "Configured" + else: + actor_name = actor_part + + # Use last 4 characters of world_id as replica identifier + # This is deterministic, readable, and works for any number of replicas + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + # Use current_rank().rank as the local rank within the replica + local_rank = rank.rank + + rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + + return rank_name + + +# Simple push +async def push_metrics( + key: str, value: Any, reduction: ReductionType = ReductionType.MEAN +) -> None: + collector = MetricCollector() + await collector.push(key, value, reduction) + + +def reduce_across_ranks( + all_local_states: List[Dict[str, Dict[str, Any]]] +) -> Dict[str, Any]: + """Reduce states across ranks per key.""" + if not all_local_states: + return {} + + # Collect unique keys across all + all_keys = set(k for states in all_local_states for k in states) + print(f"๐Ÿ”ง Reduce: Unique keys: {list(all_keys)}") + + global_metrics = {} + for key in all_keys: + metric_states = [ + states.get(key) for states in all_local_states if key in states + ] + if not metric_states: + continue + + first_red_type = metric_states[0]["reduction_type"] + # Check consistency + for state in metric_states[1:]: + if state["reduction_type"] != first_red_type: + raise ValueError( + f"Mismatched reduction types for key '{key}': {first_red_type} vs {state['reduction_type']}" + ) + + red_enum = ReductionType(first_red_type) + acc_class = red_enum.accumulator_class + reduced_value = acc_class.merge_states(metric_states) + global_metrics[key] = reduced_value + + return 
global_metrics + + +# Backend ABC +class Backend(ABC): + async def setup( + self, + config: Dict[str, Any], + role: str, + primary_states: Optional[Dict[str, Any]] = None, + ) -> None: + if primary_states is None: + primary_states = {} + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + async def finish(self) -> None: + pass + + def get_primary_state(self) -> Optional[Dict[str, Any]]: + """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + return None + + +class ConsoleBackend(Backend): + async def setup( + self, + config: Dict[str, Any], + role: str, + primary_states: Optional[Dict[str, Any]] = None, + ) -> None: + if primary_states is None: + primary_states = {} + print("ConsoleBackend: Initialized") + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + try: + rank = current_rank().rank + rank_str = f"RANK {rank}" + except Exception: + rank_str = "GLOBAL" + print(f"\n=== {rank_str} METRICS STEP {step} ===") + for key, value in metrics.items(): + print(f" {key}: {value}") + print("==============================\n") + + async def finish(self) -> None: + print("ConsoleBackend: Finished") + + +class WandbBackend(Backend): + def __init__(self, backend_config: Dict[str, Any]): + self.backend_config = backend_config + self.project = backend_config["project"] + self.group = backend_config.get("group", "experiment_group") + self.name = None + self.run = None + self.mode = backend_config.get("mode", "wandb_all_log_all") + + async def setup( + self, + config: Dict[str, Any], + role: str, + primary_states: Optional[Dict[str, Any]] = None, + ) -> None: + if primary_states is None: + primary_states = {} + self.name = ( + get_actor_name_for_logging() if role == "local" else "global_controller" + ) + + if self.mode == "wandb_rank_0_reduce_all" and role == "local": + # Should not init locals for reduce + print("WandbBackend: Skipped local init for reduce mode") + return + + if self.mode == "wandb_all_log_all" and role == "global": + print("WandbBackend: Skipped global init for all_log_all mode") + return + + if self.mode == "wandb_all_log_all": + self.run = wandb.init( + project=self.project, group=self.group, name=self.name + ) + print(f"WandbBackend: Separate run '{self.name}' in group '{self.group}'") + elif self.mode == "wandb_rank_0_log_all": + if role == "global": + # Primary + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init( + project=self.project, group=self.group, settings=settings + ) + self.run.define_metric("global_step") + self.run.define_metric("train/loss", step_metric="global_step") + self.run.define_metric("generate/tokens", step_metric="global_step") + print("๐ŸŒ Global: Defined metrics with global_step axis for shared mode") + elif role == "local": + # Secondary: Use shared_run_id from primary_states + shared_id = primary_states.get("shared_run_id") + if shared_id is None: + local_rank = current_rank().rank + raise ValueError( + f"Rank {local_rank}: Shared ID required but not provided" + ) + settings = wandb.Settings( + mode="shared", x_primary=False, x_label=self.name + ) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) + print( + f"WandbBackend: Joined shared run '{shared_id}' as secondary with label '{self.name}'" + ) + elif self.mode == "wandb_rank_0_reduce_all" and role == "global": + self.run = wandb.init(project=self.project, group=self.group) + 
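            # [Review note] The define_metric(...) calls below pin both series
            # to the custom "global_step" axis, so points logged at flush time
            # line up on one x-axis in the W&B UI regardless of wandb's
            # internal step counter.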
self.run.define_metric("global_step") + self.run.define_metric("train/loss", step_metric="global_step") + self.run.define_metric("generate/tokens", step_metric="global_step") + print("๐ŸŒ Global: Initialized single run for reduce mode") + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.run: + log_data = {**metrics, "global_step": step} + print(f"WandbBackend: About to log data: {log_data} at step {step}") + self.run.log(log_data) + print( + f"WandbBackend: Successfully logged {len(metrics)} metrics at step {step}" + ) + else: + print(f"WandbBackend: No run, skipping log for {self.name}") + + def get_primary_state(self) -> Optional[Dict[str, Any]]: + if self.run and self.mode == "wandb_rank_0_log_all": + return {"shared_run_id": self.run.id} + return None # {} for others + + async def finish(self) -> None: + if self.run: + self.run.finish() + print(f"WandbBackend {self.name}: Finished run") + + +def create_backend(backend_config: Dict[str, Any]) -> Backend: + backend_type = backend_config["class"] + if backend_type == "console": + return ConsoleBackend() + elif backend_type == "wandb": + return WandbBackend(backend_config) + else: + raise ValueError(f"Unknown backend type: {backend_type}") + + +class MetricAccumulator(ABC): + def __init__(self, reduction: ReductionType): + self.reduction_type = reduction + + @abstractmethod + def append(self, value: Any) -> None: + pass + + @abstractmethod + def get_reduced_value(self) -> Any: + pass + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + pass + + @classmethod + @abstractmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> Any: + pass + + @abstractmethod + def reset(self) -> None: + pass + + +class MeanAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.count += 1 + + def get_reduced_value(self) -> float: + return self.sum / self.count if self.count > 0 else 0.0 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "count": self.count, + } + + @classmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> float: + if not states: + return 0.0 + total_sum = sum(s["sum"] for s in states) + total_count = sum(s["count"] for s in states) + return total_sum / total_count if total_count > 0 else 0.0 + + def reset(self) -> None: + self.sum = 0.0 + self.count = 0 + + +class SumAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.total = 0.0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.total += v + + def get_reduced_value(self) -> float: + return self.total + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "total": self.total} + + @classmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> float: + if not states: + return 0.0 + return sum(s["total"] for s in states) + + def reset(self) -> None: + self.total = 0.0 + + +class MaxAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.max_val = float("-inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.max_val = max(self.max_val, v) + + def 
get_reduced_value(self) -> float: + return self.max_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} + + @classmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> float: + if not states: + return float("-inf") + return max(s["max_val"] for s in states) + + def reset(self) -> None: + self.max_val = float("-inf") + + +class MinAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.min_val = float("inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.min_val = min(self.min_val, v) + + def get_reduced_value(self) -> float: + return self.min_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} + + @classmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> float: + if not states: + return float("inf") + return min(s["min_val"] for s in states) + + def reset(self) -> None: + self.min_val = float("inf") + + +class CountAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.count = 0 + + def append(self, value: Any) -> None: + self.count += 1 + + def get_reduced_value(self) -> int: + return self.count + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "count": self.count} + + @classmethod + def merge_states(cls, states: List[Dict[str, Any]]) -> int: + if not states: + return 0 + return sum(s["count"] for s in states) + + def reset(self) -> None: + self.count = 0 + + +def create_accumulator(reduction: ReductionType) -> MetricAccumulator: + acc_class = reduction.accumulator_class + return acc_class(reduction) + + +class MetricCollector: + _instances: Dict[int, "MetricCollector"] = {} + + def __new__(cls): + rank = current_rank().rank + if rank not in cls._instances: + inst = super().__new__(cls) + cls._instances[rank] = inst + inst._singleton_rank = rank + else: + inst = cls._instances[rank] + if inst._singleton_rank != rank: + raise ValueError( + f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" + ) + return inst + + def __init__(self): + if hasattr(self, "_initialized_sync"): + return + self._initialized_sync = True + self.accumulators: Dict[str, MetricAccumulator] = {} + self.backends: List[Backend] = [] + self._initialized_async = False + self.rank = current_rank().rank + print(f"๐Ÿ”ง MetricCollector rank {self.rank}: Singleton initialized (unique)") + + async def _init(self, primary_backend_states: Dict[str, Dict[str, Any]]): + if self._initialized_async: + return + + from monarch.actor import get_or_spawn_controller + + from forge.controller.v3.metric_actors import GlobalLoggingActor + + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + + config = await global_logger.get_metric_logger_cfg.call_one() + if config is None: + raise ValueError(f"Rank {self.rank}: Config not setโ€”call init_config first") + + # Init local backends only if log_per_rank=True, inject primary states + for backend_config in config["backends"]: + if not backend_config.get("log_per_rank", True): + continue # Skip globals/reduce + cls_name = backend_config["class"] + primary_state = primary_backend_states.get(cls_name, {}) + backend = create_backend(backend_config) + await backend.setup(config, role="local", primary_states=primary_state) + 
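            # [Review note] primary_states is how controller-created state
            # reaches each rank: in wandb_rank_0_log_all mode, the controller's
            # WandbBackend.get_primary_state() returns {"shared_run_id": run.id},
            # and the local setup() above joins that shared run instead of
            # creating a new one.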
self.backends.append(backend) + print(f"๐Ÿ”ง Collector rank {self.rank}: Initialized local backend {cls_name}") + + self._initialized_async = True + print(f"๐Ÿ”ง MetricCollector rank {self.rank}: Async initialization complete") + + async def push( + self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN + ): + # Assume eager init; fallback to lazy + if not self._initialized_async: + raise ValueError("Collector not initializedโ€”call init first") + if key not in self.accumulators: + self.accumulators[key] = create_accumulator(reduction) + self.accumulators[key].append(value) + print( + f"๐Ÿ”ง Collector rank {self.rank}: Pushed {key}={value} (reduction={reduction.value})" + ) + + async def log_and_reset( + self, step: int, return_state: bool = False + ) -> Optional[Dict[str, Dict[str, Any]]]: + """Log to local backends (if any), optionally return states, and reset.""" + if not self._initialized_async: + raise ValueError("Collector not initializedโ€”call init first") + if not self.accumulators: + print(f"๐Ÿ”ง Collector rank {self.rank}: No metrics to flush for step {step}") + return {} if return_state else None + + print( + f"๐Ÿ”ง Collector rank {self.rank}: Accumulators before flush: {list(self.accumulators.keys())}" + ) + + # Snapshot states and reset immediately + states = {key: acc.get_state() for key, acc in self.accumulators.items()} + for acc in list(self.accumulators.values()): + acc.reset() + self.accumulators.clear() + + # Derive metrics from states if needed + if self.backends: + metrics = {} + for key, state in states.items(): + red_enum = ReductionType(state["reduction_type"]) + acc_class = red_enum.accumulator_class + metrics[key] = acc_class.merge_states([state]) + print(f"๐Ÿ”ง Collector rank {self.rank}: Metrics: {metrics}") + + # Log to local backends + for backend in self.backends: + await backend.log(metrics, step) + + if return_state: + print(f"๐Ÿ”ง Collector rank {self.rank}: States: {list(states.keys())}") + + print(f"๐Ÿ”ง Collector rank {self.rank}: Flushed and reset for step {step}") + return states if return_state else None + + async def shutdown(self): + """Shutdown backends if initialized.""" + if not self._initialized_async: + print(f"๐Ÿ”ง Collector rank {self.rank}: Not initialized, skipping shutdown") + return + for backend in self.backends: + await backend.finish() + print(f"๐Ÿ”ง Collector rank {self.rank}: Shutdown complete") From 4b9aadadde5128511a37a46f27c543170bb8af14 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 22 Sep 2025 14:36:46 -0700 Subject: [PATCH 03/25] delete old files --- src/forge/controller/metric_actors.py | 211 -------------------------- src/forge/controller/metric_main.py | 146 ------------------ 2 files changed, 357 deletions(-) delete mode 100644 src/forge/controller/metric_actors.py delete mode 100644 src/forge/controller/metric_main.py diff --git a/src/forge/controller/metric_actors.py b/src/forge/controller/metric_actors.py deleted file mode 100644 index 3d94eefd8..000000000 --- a/src/forge/controller/metric_actors.py +++ /dev/null @@ -1,211 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
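[Review note] On MetricCollector.log_and_reset() at the end of patch 02: a
single rank's flush is effectively a merge over one state. For example, with
the SUM accumulator defined above:

    acc = SumAccumulator(ReductionType.SUM)
    acc.append(10)
    acc.append(20)
    state = acc.get_state()               # {"reduction_type": "sum", "total": 30.0}
    SumAccumulator.merge_states([state])  # 30.0, the value sent to local backends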
- -import asyncio -from collections import defaultdict -from typing import Any, Dict, List - -from monarch.actor import Actor, endpoint - - -# ============================================================================ -# LocalLoggingActor -# ============================================================================ -class LocalLoggingActor(Actor): - """Local logging actor that accumulates metrics within a process.""" - - def __init__(self): - self._metrics = defaultdict(list) - - @endpoint - async def get_metrics(self) -> Dict[str, List[Any]]: - """Get all accumulated metrics (called by GlobalLoggingActor).""" - # Return copy and reset for next collection - result = dict(self._metrics) - self._metrics.clear() - print(f"LocalLoggingActor: Returning {len(result)} metric keys") - return result - - @endpoint - def push_metrics( - self, key: str, value: Any - ) -> None: # Note: not async for broadcast - """Store a metric value (called by service actors).""" - self._metrics[key].append(value) - print(f"LocalLoggingActor: Stored {key}={value}") - - -# ============================================================================ -# Backend System -# ============================================================================ -class Backend: - """Base class for logging backends.""" - - def push(self, metrics: Dict[str, Any], step: int) -> None: - pass - - -class ConsoleBackend(Backend): - """Simple console backend for testing.""" - - def push(self, metrics: Dict[str, Any], step: int) -> None: - print(f"\n=== METRICS STEP {step} ===") - for key, value in metrics.items(): - print(f" {key}: {value}") - print("========================\n") - - -# ============================================================================ -# GlobalLoggingActor -# ============================================================================ - - -class GlobalLoggingActor(Actor): - """Global logger that coordinates across all processes.""" - - def __init__(self): - self._loggers: Dict[str, LocalLoggingActor] = {} - self._backends: List[Backend] = [ConsoleBackend()] # Default console backend - - @endpoint - async def register(self, local_actor: LocalLoggingActor, name: str) -> None: - """Register a LocalLoggingActor from a process.""" - self._loggers[name] = local_actor - print(f"GlobalLoggingActor: Registered {name}") - - @endpoint - async def deregister(self, name: str) -> None: - """Deregister a LocalLoggingActor.""" - if name in self._loggers: - del self._loggers[name] - print(f"GlobalLoggingActor: Deregistered {name}") - - @endpoint - async def flush(self, step: int) -> None: - """Collect metrics from all processes and send to backends.""" - if not self._loggers: - print("GlobalLoggingActor: No loggers registered") - return - - print(f"GlobalLoggingActor: Flushing metrics for step {step}") - - # Collect from all local loggers - metrics_list = await asyncio.gather( - *[actor.get_metrics.call_one() for actor in self._loggers.values()] - ) - - # Simple aggregation - just combine all metrics - print("metrics_list", metrics_list) - all_metrics = {} - for metrics in metrics_list: - for key, values in metrics.items(): - if key not in all_metrics: - all_metrics[key] = [] - all_metrics[key].extend(values) - - # Send to all backends - for backend in self._backends: - backend.push(all_metrics, step) - - -# =========================================================================== - - -def debug_context(ctx, label: str = "DEBUG") -> None: - """ - Utility function to fully debug a context object and its nested attributes. 
- - Args: - ctx: The context object to debug - label: Label for this debug session - """ - print(f"\n=== {label} ===") - print(f"Context type: {type(ctx)}") - print(f"Context dir: {[attr for attr in dir(ctx) if not attr.startswith('__')]}") - - # Print all attributes - for attr in dir(ctx): - if not attr.startswith("__"): - try: - val = getattr(ctx, attr) - print(f" ctx.{attr}: {val} (type: {type(val)})") - - # f this is actor_instance, explore it - if attr == "actor_instance": - print(f" ๐Ÿ” DEEP DIVE INTO ACTOR_INSTANCE:") - print(f" Type: {type(val)}") - - # Get all attributes - all_attrs = [a for a in dir(val) if not a.startswith("__")] - print(f" All attributes: {all_attrs}") - - # Check for anything related to proc, mesh, logger, etc. - interesting_attrs = [ - a - for a in all_attrs - if any( - keyword in a.lower() - for keyword in [ - "proc", - "mesh", - "log", - "local", - "spawn", - "process", - "actor", - ] - ) - ] - print(f" Interesting attributes: {interesting_attrs}") - - # Print ALL attributes with their values - for sub_attr in all_attrs: - try: - sub_val = getattr(val, sub_attr) - print( - f" โ””โ”€ {sub_attr}: {sub_val} (type: {type(sub_val)})" - ) - - # If it's an object, go one level deeper - if hasattr(sub_val, "__dict__") or hasattr( - sub_val, "__slots__" - ): - try: - deep_attrs = [ - a - for a in dir(sub_val) - if not a.startswith("__") - ][ - :3 - ] # Just first 3 - if deep_attrs: - print( - f" โ””โ”€ {sub_attr} has: {deep_attrs}" - ) - for deep_attr in deep_attrs: - try: - deep_val = getattr(sub_val, deep_attr) - print( - f" โ””โ”€ {deep_attr}: {deep_val}" - ) - except: - print( - f" โ””โ”€ {deep_attr}: " - ) - except: - pass - except Exception as e: - print(f" โ””โ”€ {sub_attr}: ") - - # For other attributes, shorter exploration - elif hasattr(val, "__dict__") or hasattr(val, "__slots__"): - sub_attrs = [a for a in dir(val) if not a.startswith("__")] - if sub_attrs: - print(f" โ””โ”€ {attr} attributes: {sub_attrs}") - except: - print(f" ctx.{attr}: ") - - print(f"=== END {label} ===\n") diff --git a/src/forge/controller/metric_main.py b/src/forge/controller/metric_main.py deleted file mode 100644 index 78a24cfe7..000000000 --- a/src/forge/controller/metric_main.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import asyncio -from dataclasses import dataclass -from typing import Any - -from monarch.actor import context, endpoint, get_or_spawn_controller - -from forge.controller.actor import ForgeActor - -from forge.controller.metric_actors import debug_context, GlobalLoggingActor - - -def push_metrics(key: str, value: Any) -> None: - """ - Push metrics to LocalLoggingActor - - Args: - key: Metric name - value: Metric value - """ - try: - # Just use the regular monarch context - ctx = context() - - # Try to get LocalLoggingActor from context - local_logging_actor = None - - if hasattr(ctx, "actor_instance") and hasattr(ctx.actor_instance, "_proc_mesh"): - proc_mesh = ctx.actor_instance._proc_mesh - if proc_mesh is not None and hasattr(proc_mesh, "_local_logger"): - local_logging_actor = proc_mesh._local_logger - print(f"โœ… Found LocalLoggingActor via context for {key}") - - if local_logging_actor: - local_logging_actor.push_metrics.broadcast(key, value) - else: - print( - f"โŒ No LocalLoggingActor found in context, dropping metric {key}={value}" - ) - debug_context(ctx, f"CONTEXT DEBUG for {key}={value}") - - except Exception as e: - print(f"โŒ push_metrics failed for {key}={value}: {e}") - import traceback - - traceback.print_exc() - - -async def flush(step: int) -> None: - """Flush all metrics globally.""" - try: - g = await get_or_spawn_controller("global_logger", GlobalLoggingActor) - await g.flush.call_one(step) - except Exception as e: - print(f"โŒ flush failed: {e}") - import traceback - - traceback.print_exc() - - -@dataclass -class Trainer(ForgeActor): - """Trainer that uses global_logger.push_metrics.""" - - def __post_init__(self): - self.step_counter = 0 - - @endpoint - async def train_step(self) -> int: - """Simulate one training step.""" - self.step_counter += 1 - push_metrics("step_counter", self.step_counter) - print(f"Trainer: Completed step {self.step_counter}") - return self.step_counter - - @endpoint - async def debug_context(self) -> None: - """Debug what the service actor can see.""" - ctx = context() - debug_context(ctx) - - -# ============================================================================ -# Main Training Loop -# ============================================================================ - - -async def continuous_training(trainer: Trainer, num_steps: int = 5): - """Run training loop with periodic flushing.""" - print(f"\ Starting training for {num_steps} steps...") - - for step in range(num_steps): - print(f"\n--- Step {step + 1} ---") - - # Run training step - await trainer.train_step.choose() - - if (step + 1) % 2 == 0: # Flush every 2 steps - print(f"๐Ÿ”„ Flushing metrics at step {step + 1}") - await flush(step + 1) - - await asyncio.sleep(0.1) - - print("โœ… Training completed!") - - -async def main(): - """Main function demonstrating the REAL issue (following your architecture).""" - - print("1. Spawning trainer service...") - service_config = {"procs_per_replica": 1, "num_replicas": 1, "with_gpus": False} - - # This should internally: - # - Call get_proc_mesh() - # - Spawn LocalLoggingActor in that process (via provisioner.py changes) - # - Register LocalLoggingActor with GlobalLoggingActor - # - Make LocalLoggingActor accessible to the Trainer via context - trainer = await Trainer.options(**service_config).as_service() - - # Debug what the service actor can see - print("\n2. Debugging service actor context...") - await trainer.debug_context.choose() - - # Test the full training loop with metrics + flushing - print("\n3. 
Running training loop with metrics + flushing...") - await continuous_training(trainer, num_steps=2) - - # Shutdown - print("\n4. Shutting down...") - await trainer.shutdown() - - -if __name__ == "__main__": - try: - asyncio.run(main()) - except Exception as e: - print(f"\n Failed with error: {e}") - import traceback - - traceback.print_exc() - raise From 2864324d32677111831f69c9ba99015b0aff3b43 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 22 Sep 2025 22:35:02 -0700 Subject: [PATCH 04/25] refactoring + docstrings --- .../toy_metrics/main.py | 17 +- src/forge/controller/provisioner.py | 65 +- src/forge/controller/v3/metric_actors.py | 165 ----- src/forge/controller/v3/metrics.py | 571 ---------------- .../v3 => observability}/__init__.py | 0 src/forge/observability/metric_actors.py | 193 ++++++ src/forge/observability/metrics.py | 633 ++++++++++++++++++ 7 files changed, 863 insertions(+), 781 deletions(-) rename src/forge/controller/v3/metric_main.py => apps/toy_metrics/main.py (87%) delete mode 100644 src/forge/controller/v3/metric_actors.py delete mode 100644 src/forge/controller/v3/metrics.py rename src/forge/{controller/v3 => observability}/__init__.py (100%) create mode 100644 src/forge/observability/metric_actors.py create mode 100644 src/forge/observability/metrics.py diff --git a/src/forge/controller/v3/metric_main.py b/apps/toy_metrics/main.py similarity index 87% rename from src/forge/controller/v3/metric_main.py rename to apps/toy_metrics/main.py index 5f81ba1b7..586d6780d 100644 --- a/src/forge/controller/v3/metric_main.py +++ b/apps/toy_metrics/main.py @@ -5,14 +5,18 @@ # LICENSE file in the root directory of this source tree. import asyncio + +import logging import sys import time +from forge.controller.actor import ForgeActor +from forge.observability.metric_actors import GlobalLoggingActor +from forge.observability.metrics import record_metric, ReductionType + from monarch.actor import current_rank, endpoint, get_or_spawn_controller -from forge.controller.actor import ForgeActor -from forge.controller.v3.metric_actors import GlobalLoggingActor -from forge.controller.v3.metrics import push_metrics, ReductionType +logging.basicConfig(level=logging.INFO) class TrainActor(ForgeActor): @@ -21,7 +25,7 @@ async def train_step(self, step: int): rank = current_rank().rank value = rank * 1000 + 100 * step print(f"๐Ÿ”ง Train rank {rank}: Step {step}, loss={value}") - await push_metrics("train/loss", value) + await record_metric("train/loss", value) class GeneratorActor(ForgeActor): @@ -30,7 +34,7 @@ async def generate_step(self, step: int, substep: int): rank = current_rank().rank value = rank * 1000 + step * 100 + substep * 10 print(f"๐ŸŽฏ Gen rank {rank}: Step {step}.{substep}, tokens={value}") - await push_metrics("generate/tokens", value, ReductionType.SUM) + await record_metric("generate/tokens", value, ReductionType.SUM) # Main @@ -49,6 +53,7 @@ async def main(mode: str = "wandb_all_log_all"): ] elif mode == "wandb_rank_0_reduce_all": backends = [ + {"class": "console", "log_per_rank": False}, { "class": "wandb", "project": "my_project", @@ -71,7 +76,7 @@ async def main(mode: str = "wandb_all_log_all"): logging_config = { "backends": backends, } - service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} + service_config = {"procs_per_replica": 2, "num_replicas": 2, "with_gpus": False} # Spawn services first (triggers registrations via provisioner hook) trainer = await TrainActor.options(**service_config).as_service() diff --git 
a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index c22d29bb5..c2a7be328 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -126,6 +126,24 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: server_name, ) + async def _setup_logging(self, procs: ProcMesh) -> None: + """Spawn and register local fetcher for metrics on each process.""" + from forge.observability.metric_actors import ( + GlobalLoggingActor, + LocalFetcherActor, + ) + + local_fetcher_actor = await procs.spawn( + "local_fetcher_actor", LocalFetcherActor + ) + procs._local_fetcher = local_fetcher_actor + + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + process_name = f"proc_mesh_{id(procs)}" + await global_logger.register_fetcher.call_one(local_fetcher_actor, process_name) + async def get_proc_mesh( self, num_procs: int, with_gpus: bool = False, num_hosts: int | None = None ): @@ -192,34 +210,7 @@ def bootstrap(gpu_ids: int): procs._host = host_mesh # Spawn local logging actor on each process and register with global logger - try: - from forge.controller.v3.metric_actors import ( - GlobalLoggingActor, - LocalFetcherActor, - ) - - local_fetcher_actor = await procs.spawn( - "local_fetcher_actor", LocalFetcherActor - ) - procs._local_fetcher = local_fetcher_actor - - # Register with global logger - global_logger = await get_or_spawn_controller( - "global_logger", GlobalLoggingActor - ) - process_name = f"proc_mesh_{id(procs)}" - await global_logger.register_fetcher.call_one( - local_fetcher_actor, process_name - ) - - logger.debug( - f"Spawned and registered LocalFetcherActor for {process_name}" - ) - except Exception as e: - logger.warning(f"Failed to spawn LocalFetcherActor: {e}") - import traceback - - traceback.print_stack() + await self._setup_logging(procs) # If we created a server, track so we can tear it down later. if server_name: @@ -233,17 +224,13 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async with self._lock: # Deregister local logger from global logger if hasattr(proc_mesh, "_local_fetcher"): - try: - from forge.controller.v3.metric_actors import GlobalLoggingActor - - global_logger = await get_or_spawn_controller( - "global_logger", GlobalLoggingActor - ) - process_name = f"proc_mesh_{id(proc_mesh)}" - await global_logger.deregister.call_one(process_name) - logger.debug(f"Deregistered LocalLoggingActor for {process_name}") - except Exception as e: - logger.warning(f"Failed to deregister LocalLoggingActor: {e}") + from forge.observability.metric_actors import GlobalLoggingActor + + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + process_name = f"proc_mesh_{id(proc_mesh)}" + await global_logger.deregister.call_one(process_name) if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] diff --git a/src/forge/controller/v3/metric_actors.py b/src/forge/controller/v3/metric_actors.py deleted file mode 100644 index 00b33ae9c..000000000 --- a/src/forge/controller/v3/metric_actors.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
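[Editorial aside on the provisioner hook above: `_setup_logging` reduces to a small registry pattern, one fetcher per process mesh, registered with a single global coordinator under a unique key and flushed in parallel. Below is a minimal, Monarch-free sketch of that pattern, assuming nothing beyond asyncio; `Fetcher`, `GlobalCoordinator`, and `demo` are illustrative names, not forge or Monarch APIs.]

import asyncio


class Fetcher:
    """Stand-in for the per-process LocalFetcherActor."""

    def __init__(self) -> None:
        self.metrics: dict[str, float] = {}

    async def get_and_reset(self) -> dict[str, float]:
        # Hand back accumulated metrics and start a fresh window.
        out, self.metrics = self.metrics, {}
        return out


class GlobalCoordinator:
    """Stand-in for the GlobalLoggingActor's registry of fetchers."""

    def __init__(self) -> None:
        self.fetchers: dict[str, Fetcher] = {}

    def register(self, name: str, fetcher: Fetcher) -> None:
        self.fetchers[name] = fetcher

    def deregister(self, name: str) -> None:
        self.fetchers.pop(name, None)

    async def flush(self) -> dict[str, float]:
        # Gather from all registered fetchers in parallel, as flush() does above.
        results = await asyncio.gather(
            *(f.get_and_reset() for f in self.fetchers.values())
        )
        merged: dict[str, float] = {}
        for state in results:
            merged.update(state)
        return merged


async def demo() -> None:
    coord = GlobalCoordinator()
    fetcher = Fetcher()
    # Mirrors _setup_logging's "proc_mesh_{id(procs)}" registration key.
    coord.register(f"proc_mesh_{id(fetcher)}", fetcher)
    fetcher.metrics["loss"] = 1.5
    print(await coord.flush())  # {'loss': 1.5}


asyncio.run(demo())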
- -import asyncio -from typing import Any, Dict, Optional - -from monarch.actor import Actor, current_rank, endpoint - - -class LocalFetcherActor(Actor): - @endpoint - async def log_and_reset( - self, step: int, return_state: bool = False - ) -> Optional[Dict[str, Dict[str, Any]]]: - """Log to local backends (if any), optionally return states, and reset.""" - from forge.controller.v3.metrics import MetricCollector - - collector = MetricCollector() - result = await collector.log_and_reset(step, return_state=return_state) - print( - f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Flushed step {step}, returned state: {return_state}" - ) - return result - - @endpoint - async def init_collector(self, primary_backend_states: Dict[str, Dict[str, Any]]): - from forge.controller.v3.metrics import MetricCollector - - collector = MetricCollector() - await collector._init(primary_backend_states) - print(f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Initialized collector") - - @endpoint - async def shutdown(self): - from forge.controller.v3.metrics import MetricCollector - - collector = MetricCollector() - await collector.shutdown() - print(f"๐ŸŽฏ Fetcher rank {current_rank().rank}: Finished all backends") - - -# GlobalLoggingActor (coordinator) -class GlobalLoggingActor(Actor): - def __init__(self): - self.fetchers: Dict[str, LocalFetcherActor] = {} - self.config: Optional[Dict[str, Any]] = None - self.global_backends: Dict[str, "Backend"] = {} - self.primary_backend_states: Dict[str, Dict[str, Any]] = {} - - @endpoint - async def init_config(self, config: Dict[str, Any]): - """Main calls this to set config and init global backends if needed.""" - self.config = config - - # Validate unique classes - classes = [b["class"] for b in config["backends"]] - if len(set(classes)) != len(classes): - raise ValueError("Duplicate backend classes in config") - - # Init global backends and states where needed - from forge.controller.v3.metrics import create_backend - - for backend_config in config["backends"]: - cls_name = backend_config["class"] - backend = create_backend( - backend_config - ) # Factory: returns instance based on type - - await backend.setup(self.config, role="global") - primary_state = backend.get_primary_state() or {} - log_per_rank = backend_config.get("log_per_rank", True) - if log_per_rank: - self.primary_backend_states[cls_name] = primary_state - if not log_per_rank: - self.global_backends[cls_name] = backend - print( - f"๐ŸŒ Global: Processed backend {cls_name} (log_per_rank: {log_per_rank})" - ) - - # Eager init collectors on all registered fetchers in parallel, passing primary states - if self.fetchers: - tasks = [ - fetcher.init_collector.call(self.primary_backend_states) - for fetcher in self.fetchers.values() - ] - await asyncio.gather(*tasks, return_exceptions=True) - print(f"๐ŸŒ Global: Initialized {len(tasks)} collectors in parallel") - - print("๐ŸŒ Global: Config set") - - @endpoint - def get_metric_logger_cfg(self) -> Optional[Dict[str, Any]]: - return self.config - - @endpoint - async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): - self.fetchers[name] = fetcher - print(f"๐ŸŒ Global: Registered {name} (total: {len(self.fetchers)})") - - @endpoint - async def flush_global(self, step: int): - if not self.fetchers: - print("๐ŸŒ Global: No fetchers") - return - - print(f"๐ŸŒ Global: Flushing step {step} across {len(self.fetchers)}") - - config = self.config - has_reduce = any(not b.get("log_per_rank", True) for b in config["backends"]) - return_state = has_reduce # 
Flag for reduce - - # Broadcast log_and_reset to all fetchers - results = await asyncio.gather( - *[ - f.log_and_reset.call(step, return_state=return_state) - for f in self.fetchers.values() - ], - return_exceptions=True, - ) - - if has_reduce: - # Flatten: Handle both single-process (dict/None) and multi-process (list of dicts/None) - all_local_results = [] - for res in results: - res = ( - res._values - ) # TODO: avoid using internal state. Could use items() instead, but has to parse metadata. - if isinstance(res, list): - all_local_results.extend(res) - elif res is not None: - all_local_results.append(res) - - # Filter states from results (None if not returned) - all_local_states = [r for r in all_local_results if isinstance(r, dict)] - if not all_local_states: - print("๐ŸŒ Global: No local states gathered") - return - - # Reduce - from forge.controller.v3.metrics import reduce_across_ranks - - reduced_metrics = reduce_across_ranks(all_local_states) - - # Log to each global backend - for backend_name, backend in self.global_backends.items(): - await backend.log(reduced_metrics, step) - print(f"๐ŸŒ Global: Logged reduced metrics {reduced_metrics} at step {step}") - - @endpoint - async def shutdown(self): - # Finish per-rank backends via fetchers - if self.fetchers: - tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] - await asyncio.gather(*tasks, return_exceptions=True) - print(f"๐ŸŒ Global: Finished {len(self.fetchers)} fetchers' backends") - - # Finish global backends - for backend_name, backend in self.global_backends.items(): - await backend.finish() - print(f"๐ŸŒ Global: Finished global backend {backend_name}") - - print("๐ŸŒ Global: Shutdown complete") diff --git a/src/forge/controller/v3/metrics.py b/src/forge/controller/v3/metrics.py deleted file mode 100644 index 00059ef73..000000000 --- a/src/forge/controller/v3/metrics.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Dict, List, Optional - -import wandb -from monarch.actor import context, current_rank - - -# Reduction Types -class ReductionType(Enum): - MEAN = "mean" - SUM = "sum" - MAX = "max" - MIN = "min" - COUNT = "count" - - @property - def accumulator_class(self): - - mapping = { - ReductionType.MEAN: MeanAccumulator, - ReductionType.SUM: SumAccumulator, - ReductionType.MAX: MaxAccumulator, - ReductionType.MIN: MinAccumulator, - ReductionType.COUNT: CountAccumulator, - } - return mapping[self] - - -def get_actor_name_for_logging(): - """ - Extract actor information from Monarch context and return formatted name for logging. - Returns string like "{actor_type}_{replica_id[-6:]}_r{local_rank_int}" - - #TODO: this is flaky as it currently relies on string parsing. 
- """ - - # Add more defensive checks - ctx = context() - if ctx is None: - print("โš ๏ธ Warning: context() returned None") - return "UnknownActor_r0_l0" - - actor_instance = ctx.actor_instance - if actor_instance is None: - print("โš ๏ธ Warning: actor_instance is None") - return "UnknownActor_r0_l0" - - rank = current_rank() - if rank is None: - print("โš ๏ธ Warning: current_rank() returned None") - return "UnknownActor_r0_l0" - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor_r0_l0" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" - - return rank_name - - -# Simple push -async def push_metrics( - key: str, value: Any, reduction: ReductionType = ReductionType.MEAN -) -> None: - collector = MetricCollector() - await collector.push(key, value, reduction) - - -def reduce_across_ranks( - all_local_states: List[Dict[str, Dict[str, Any]]] -) -> Dict[str, Any]: - """Reduce states across ranks per key.""" - if not all_local_states: - return {} - - # Collect unique keys across all - all_keys = set(k for states in all_local_states for k in states) - print(f"๐Ÿ”ง Reduce: Unique keys: {list(all_keys)}") - - global_metrics = {} - for key in all_keys: - metric_states = [ - states.get(key) for states in all_local_states if key in states - ] - if not metric_states: - continue - - first_red_type = metric_states[0]["reduction_type"] - # Check consistency - for state in metric_states[1:]: - if state["reduction_type"] != first_red_type: - raise ValueError( - f"Mismatched reduction types for key '{key}': {first_red_type} vs {state['reduction_type']}" - ) - - red_enum = ReductionType(first_red_type) - acc_class = red_enum.accumulator_class - reduced_value = acc_class.merge_states(metric_states) - global_metrics[key] = reduced_value - - return global_metrics - - -# Backend ABC -class Backend(ABC): - async def setup( - self, - config: Dict[str, Any], - role: str, - primary_states: Optional[Dict[str, Any]] = None, - ) -> None: - if primary_states is None: - primary_states = {} - pass - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - pass - - async def finish(self) -> None: - pass - - def get_primary_state(self) -> Optional[Dict[str, Any]]: - """Return sharable state after primary init (e.g., for shared modes). 
Called only on globals.""" - return None - - -class ConsoleBackend(Backend): - async def setup( - self, - config: Dict[str, Any], - role: str, - primary_states: Optional[Dict[str, Any]] = None, - ) -> None: - if primary_states is None: - primary_states = {} - print("ConsoleBackend: Initialized") - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - try: - rank = current_rank().rank - rank_str = f"RANK {rank}" - except Exception: - rank_str = "GLOBAL" - print(f"\n=== {rank_str} METRICS STEP {step} ===") - for key, value in metrics.items(): - print(f" {key}: {value}") - print("==============================\n") - - async def finish(self) -> None: - print("ConsoleBackend: Finished") - - -class WandbBackend(Backend): - def __init__(self, backend_config: Dict[str, Any]): - self.backend_config = backend_config - self.project = backend_config["project"] - self.group = backend_config.get("group", "experiment_group") - self.name = None - self.run = None - self.mode = backend_config.get("mode", "wandb_all_log_all") - - async def setup( - self, - config: Dict[str, Any], - role: str, - primary_states: Optional[Dict[str, Any]] = None, - ) -> None: - if primary_states is None: - primary_states = {} - self.name = ( - get_actor_name_for_logging() if role == "local" else "global_controller" - ) - - if self.mode == "wandb_rank_0_reduce_all" and role == "local": - # Should not init locals for reduce - print("WandbBackend: Skipped local init for reduce mode") - return - - if self.mode == "wandb_all_log_all" and role == "global": - print("WandbBackend: Skipped global init for all_log_all mode") - return - - if self.mode == "wandb_all_log_all": - self.run = wandb.init( - project=self.project, group=self.group, name=self.name - ) - print(f"WandbBackend: Separate run '{self.name}' in group '{self.group}'") - elif self.mode == "wandb_rank_0_log_all": - if role == "global": - # Primary - settings = wandb.Settings( - mode="shared", x_primary=True, x_label="controller_primary" - ) - self.run = wandb.init( - project=self.project, group=self.group, settings=settings - ) - self.run.define_metric("global_step") - self.run.define_metric("train/loss", step_metric="global_step") - self.run.define_metric("generate/tokens", step_metric="global_step") - print("๐ŸŒ Global: Defined metrics with global_step axis for shared mode") - elif role == "local": - # Secondary: Use shared_run_id from primary_states - shared_id = primary_states.get("shared_run_id") - if shared_id is None: - local_rank = current_rank().rank - raise ValueError( - f"Rank {local_rank}: Shared ID required but not provided" - ) - settings = wandb.Settings( - mode="shared", x_primary=False, x_label=self.name - ) - self.run = wandb.init( - id=shared_id, - project=self.project, - group=self.group, - settings=settings, - ) - print( - f"WandbBackend: Joined shared run '{shared_id}' as secondary with label '{self.name}'" - ) - elif self.mode == "wandb_rank_0_reduce_all" and role == "global": - self.run = wandb.init(project=self.project, group=self.group) - self.run.define_metric("global_step") - self.run.define_metric("train/loss", step_metric="global_step") - self.run.define_metric("generate/tokens", step_metric="global_step") - print("๐ŸŒ Global: Initialized single run for reduce mode") - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - if self.run: - log_data = {**metrics, "global_step": step} - print(f"WandbBackend: About to log data: {log_data} at step {step}") - self.run.log(log_data) - print( - f"WandbBackend: Successfully 
logged {len(metrics)} metrics at step {step}" - ) - else: - print(f"WandbBackend: No run, skipping log for {self.name}") - - def get_primary_state(self) -> Optional[Dict[str, Any]]: - if self.run and self.mode == "wandb_rank_0_log_all": - return {"shared_run_id": self.run.id} - return None # {} for others - - async def finish(self) -> None: - if self.run: - self.run.finish() - print(f"WandbBackend {self.name}: Finished run") - - -def create_backend(backend_config: Dict[str, Any]) -> Backend: - backend_type = backend_config["class"] - if backend_type == "console": - return ConsoleBackend() - elif backend_type == "wandb": - return WandbBackend(backend_config) - else: - raise ValueError(f"Unknown backend type: {backend_type}") - - -class MetricAccumulator(ABC): - def __init__(self, reduction: ReductionType): - self.reduction_type = reduction - - @abstractmethod - def append(self, value: Any) -> None: - pass - - @abstractmethod - def get_reduced_value(self) -> Any: - pass - - @abstractmethod - def get_state(self) -> Dict[str, Any]: - pass - - @classmethod - @abstractmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> Any: - pass - - @abstractmethod - def reset(self) -> None: - pass - - -class MeanAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.sum = 0.0 - self.count = 0 - - def append(self, value: Any) -> None: - v = float(value.item() if hasattr(value, "item") else value) - self.sum += v - self.count += 1 - - def get_reduced_value(self) -> float: - return self.sum / self.count if self.count > 0 else 0.0 - - def get_state(self) -> Dict[str, Any]: - return { - "reduction_type": self.reduction_type.value, - "sum": self.sum, - "count": self.count, - } - - @classmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> float: - if not states: - return 0.0 - total_sum = sum(s["sum"] for s in states) - total_count = sum(s["count"] for s in states) - return total_sum / total_count if total_count > 0 else 0.0 - - def reset(self) -> None: - self.sum = 0.0 - self.count = 0 - - -class SumAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.total = 0.0 - - def append(self, value: Any) -> None: - v = float(value.item() if hasattr(value, "item") else value) - self.total += v - - def get_reduced_value(self) -> float: - return self.total - - def get_state(self) -> Dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "total": self.total} - - @classmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> float: - if not states: - return 0.0 - return sum(s["total"] for s in states) - - def reset(self) -> None: - self.total = 0.0 - - -class MaxAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.max_val = float("-inf") - - def append(self, value: Any) -> None: - v = float(value.item() if hasattr(value, "item") else value) - self.max_val = max(self.max_val, v) - - def get_reduced_value(self) -> float: - return self.max_val - - def get_state(self) -> Dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} - - @classmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> float: - if not states: - return float("-inf") - return max(s["max_val"] for s in states) - - def reset(self) -> None: - self.max_val = float("-inf") - - -class MinAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - 
super().__init__(reduction) - self.min_val = float("inf") - - def append(self, value: Any) -> None: - v = float(value.item() if hasattr(value, "item") else value) - self.min_val = min(self.min_val, v) - - def get_reduced_value(self) -> float: - return self.min_val - - def get_state(self) -> Dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} - - @classmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> float: - if not states: - return float("inf") - return min(s["min_val"] for s in states) - - def reset(self) -> None: - self.min_val = float("inf") - - -class CountAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.count = 0 - - def append(self, value: Any) -> None: - self.count += 1 - - def get_reduced_value(self) -> int: - return self.count - - def get_state(self) -> Dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "count": self.count} - - @classmethod - def merge_states(cls, states: List[Dict[str, Any]]) -> int: - if not states: - return 0 - return sum(s["count"] for s in states) - - def reset(self) -> None: - self.count = 0 - - -def create_accumulator(reduction: ReductionType) -> MetricAccumulator: - acc_class = reduction.accumulator_class - return acc_class(reduction) - - -class MetricCollector: - _instances: Dict[int, "MetricCollector"] = {} - - def __new__(cls): - rank = current_rank().rank - if rank not in cls._instances: - inst = super().__new__(cls) - cls._instances[rank] = inst - inst._singleton_rank = rank - else: - inst = cls._instances[rank] - if inst._singleton_rank != rank: - raise ValueError( - f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" - ) - return inst - - def __init__(self): - if hasattr(self, "_initialized_sync"): - return - self._initialized_sync = True - self.accumulators: Dict[str, MetricAccumulator] = {} - self.backends: List[Backend] = [] - self._initialized_async = False - self.rank = current_rank().rank - print(f"๐Ÿ”ง MetricCollector rank {self.rank}: Singleton initialized (unique)") - - async def _init(self, primary_backend_states: Dict[str, Dict[str, Any]]): - if self._initialized_async: - return - - from monarch.actor import get_or_spawn_controller - - from forge.controller.v3.metric_actors import GlobalLoggingActor - - global_logger = await get_or_spawn_controller( - "global_logger", GlobalLoggingActor - ) - - config = await global_logger.get_metric_logger_cfg.call_one() - if config is None: - raise ValueError(f"Rank {self.rank}: Config not setโ€”call init_config first") - - # Init local backends only if log_per_rank=True, inject primary states - for backend_config in config["backends"]: - if not backend_config.get("log_per_rank", True): - continue # Skip globals/reduce - cls_name = backend_config["class"] - primary_state = primary_backend_states.get(cls_name, {}) - backend = create_backend(backend_config) - await backend.setup(config, role="local", primary_states=primary_state) - self.backends.append(backend) - print(f"๐Ÿ”ง Collector rank {self.rank}: Initialized local backend {cls_name}") - - self._initialized_async = True - print(f"๐Ÿ”ง MetricCollector rank {self.rank}: Async initialization complete") - - async def push( - self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN - ): - # Assume eager init; fallback to lazy - if not self._initialized_async: - raise ValueError("Collector not initializedโ€”call init first") - if key not in self.accumulators: - 
self.accumulators[key] = create_accumulator(reduction) - self.accumulators[key].append(value) - print( - f"๐Ÿ”ง Collector rank {self.rank}: Pushed {key}={value} (reduction={reduction.value})" - ) - - async def log_and_reset( - self, step: int, return_state: bool = False - ) -> Optional[Dict[str, Dict[str, Any]]]: - """Log to local backends (if any), optionally return states, and reset.""" - if not self._initialized_async: - raise ValueError("Collector not initializedโ€”call init first") - if not self.accumulators: - print(f"๐Ÿ”ง Collector rank {self.rank}: No metrics to flush for step {step}") - return {} if return_state else None - - print( - f"๐Ÿ”ง Collector rank {self.rank}: Accumulators before flush: {list(self.accumulators.keys())}" - ) - - # Snapshot states and reset immediately - states = {key: acc.get_state() for key, acc in self.accumulators.items()} - for acc in list(self.accumulators.values()): - acc.reset() - self.accumulators.clear() - - # Derive metrics from states if needed - if self.backends: - metrics = {} - for key, state in states.items(): - red_enum = ReductionType(state["reduction_type"]) - acc_class = red_enum.accumulator_class - metrics[key] = acc_class.merge_states([state]) - print(f"๐Ÿ”ง Collector rank {self.rank}: Metrics: {metrics}") - - # Log to local backends - for backend in self.backends: - await backend.log(metrics, step) - - if return_state: - print(f"๐Ÿ”ง Collector rank {self.rank}: States: {list(states.keys())}") - - print(f"๐Ÿ”ง Collector rank {self.rank}: Flushed and reset for step {step}") - return states if return_state else None - - async def shutdown(self): - """Shutdown backends if initialized.""" - if not self._initialized_async: - print(f"๐Ÿ”ง Collector rank {self.rank}: Not initialized, skipping shutdown") - return - for backend in self.backends: - await backend.finish() - print(f"๐Ÿ”ง Collector rank {self.rank}: Shutdown complete") diff --git a/src/forge/controller/v3/__init__.py b/src/forge/observability/__init__.py similarity index 100% rename from src/forge/controller/v3/__init__.py rename to src/forge/observability/__init__.py diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py new file mode 100644 index 000000000..92a584a23 --- /dev/null +++ b/src/forge/observability/metric_actors.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +from typing import Any, Dict + +from monarch.actor import Actor, endpoint + + +logger = logging.getLogger(__name__) + + +class LocalFetcherActor(Actor): + """ + Thin per-process actor to trigger MetricCollector ops without direct access, + used by GlobalLoggingActor to broadcast inits/flushes across ranks. + """ + + @endpoint + async def log_and_reset( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): train step used by backends to align all metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. 
+ """ + from forge.observability.metrics import MetricCollector + + collector = MetricCollector() + result = await collector.log_and_reset(step, return_state=return_state) + return result + + @endpoint + async def init_collector( + self, metadata_per_primary_backend: Dict[str, Dict[str, Any]] + ): + from forge.observability.metrics import MetricCollector + + collector = MetricCollector() + await collector.init_local_backends(metadata_per_primary_backend) + + @endpoint + async def shutdown(self): + from forge.observability.metrics import MetricCollector + + collector = MetricCollector() + await collector.shutdown() + + +class GlobalLoggingActor(Actor): + def __init__(self): + self.fetchers: Dict[str, LocalFetcherActor] = {} + self.config: Dict[str, Any] | None = None + self.global_logger_backends: Dict[str, "LoggerBackend"] = {} + self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} + + @endpoint + async def init_config(self, config: Dict[str, Any]): + """ + Sets config on global actor and inits backends; broadcasts to registered per-rank fetchers. + + - Validates unique backend classes; + - Extracts metadata from a primary logger to be shared with secondary loggers (e.g., shared run IDs) for per-rank modes. + - Eagerly inits metric collectors on fetchers. + + Args: + config (Dict[str, Any]): Config for metric logging + """ + self.config = config + + # Validate unique classes + classes = [b["class"] for b in config["backends"]] + if len(set(classes)) != len(classes): + raise ValueError("Duplicate logger_backend classes in config") + + # Init global logger_backends and states where needed + from forge.observability.metrics import get_logger_backend_class + + for backend_config in config["backends"]: + cls_name = backend_config["class"] + backend = get_logger_backend_class(cls_name)(backend_config) + await backend.init(role="global") + + # Extract metadata from primary logger to be shared with secondary loggers + # and store it + log_per_rank = backend_config.get("log_per_rank", True) + if log_per_rank: + primary_backend_metadata = ( + backend.get_metadata_for_secondary_ranks() or {} + ) + self.metadata_per_primary_backend[cls_name] = primary_backend_metadata + + # Store global logger backends + if not log_per_rank: + self.global_logger_backends[cls_name] = backend + + # Eager init collectors on all registered fetchers in parallel, passing primary states + if self.fetchers: + tasks = [ + fetcher.init_collector.call(self.metadata_per_primary_backend) + for fetcher in self.fetchers.values() + ] + await asyncio.gather(*tasks, return_exceptions=True) + + @endpoint + def get_config(self) -> Dict[str, Any] | None: + """ + Returns the stored logging config for MetricCollector to query during init, + so they can be initialized in their own process. + """ + return self.config + + @endpoint + async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): + self.fetchers[name] = fetcher + + @endpoint + async def flush_global(self, step: int): + """ + Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors + log to local backends and return states if needed for cross-rank reduction. + + Args: + step (int): Global step for logging. 
+ """ + if not self.fetchers: + return + + config = self.config + has_reduce = any(not b.get("log_per_rank", True) for b in config["backends"]) + + logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + + # Broadcast log_and_reset to all fetchers + results = await asyncio.gather( + *[ + f.log_and_reset.call(step, return_state=has_reduce) + for f in self.fetchers.values() + ], + return_exceptions=True, + ) + + if has_reduce: + # Handle exceptions and extract values from ValueMesh results + all_local_states = [] + for res in results: + if isinstance(res, Exception): + logger.warning(f"Flush failed on a fetcher: {res}") + continue + # res is a ValueMesh. TODO: use public API (.items()), but need to parse metadata + res = res._values + if isinstance(res, list): + all_local_states.extend( + [r for r in res if isinstance(r, dict) and r] + ) + elif isinstance(res, dict) and res: + all_local_states.append(res) + + if not all_local_states: + logger.warning(f"No states to reduce for step {step}") + return + + # Reduce + from forge.observability.metrics import reduce_metrics_states + + reduced_metrics = reduce_metrics_states(all_local_states) + + # Log to each global logger_backend + for ( + logger_backend_name, + logger_backend, + ) in self.global_logger_backends.items(): + await logger_backend.log(reduced_metrics, step) + + @endpoint + async def shutdown(self): + # Finish per-rank logger_backends via fetchers + if self.fetchers: + tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] + await asyncio.gather(*tasks, return_exceptions=True) + # Finish global logger_backends + for logger_backend_name, logger_backend in self.global_logger_backends.items(): + await logger_backend.finish() diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py new file mode 100644 index 000000000..e6632db90 --- /dev/null +++ b/src/forge/observability/metrics.py @@ -0,0 +1,633 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional + +from monarch.actor import context, current_rank + + +logger = logging.getLogger(__name__) + + +# Reduction Types +class ReductionType(Enum): + MEAN = "mean" + SUM = "sum" + MAX = "max" + MIN = "min" + COUNT = "count" + + @property + def accumulator_class(self): + mapping = { + ReductionType.MEAN: MeanAccumulator, + ReductionType.SUM: SumAccumulator, + ReductionType.MAX: MaxAccumulator, + ReductionType.MIN: MinAccumulator, + ReductionType.COUNT: CountAccumulator, + } + return mapping[self] + + +def get_actor_name_with_rank() -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Returns a string like "TrainActor_abcd_r0" (actor type + replica ID suffix + local rank). + Relies on parsing actor_id string; fallback to "UnknownActor" if context unavailable. + + # TODO: Replace string parsing with structured actor_id access once Monarch exposes it. 
+ """ + # Add more defensive checks + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + rank_name = "UnknownActor" # fallback + if len(parts) >= 2: + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Extract world ID and proc rank + world_id = world_part.split("[")[0] if "[" in world_part else world_part + + # Extract clean actor name (remove "Configured" suffix if present) + if "[" in actor_part: + actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" + if actor_name.endswith("Configured"): + actor_name = actor_name[:-10] # Remove "Configured" + else: + actor_name = actor_part + + # Use last 4 characters of world_id as replica identifier + # This is deterministic, readable, and works for any number of replicas + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + # Use current_rank().rank as the local rank within the replica + local_rank = rank.rank + + rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + + return rank_name + + +# Simple push +async def record_metric( + key: str, value: Any, reduction: ReductionType = ReductionType.MEAN +) -> None: + """ + Records a metric value for later reduction and logging. + + Relies on a per-rank MetricCollector singleton for ease of use, i.e. + call `record_metric` anywhere in the code without moving the + collector from function to function. + + The collector methods are triggered per-rank by a + `forge.observability.metric_actors.LocalFetcherActor`, instantiated + during actor initialization. + + Records are flushed after every N train steps, triggered by + `forge.observability.metric_actors.GlobalLoggingActor` + """ + collector = MetricCollector() + await collector.push(key, value, reduction) + + +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: + """Reduce metric accumulators states to a single value per metric. + + Can be used when reducing metrics across ranks or services, as merging + states is more precise than merging locally reduced metrics. + + Args: + states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics, + normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. + + Returns: + Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + + Example: + states = [ + {"loss": {"count": 5, "sum": 14, "reduction_type": ReductionType.MEAN}}, + {"loss": {"count": 10, "sum": 16, "reduction_type": ReductionType.MEAN}}, + ] + reduce_metrics_states(states) + >>> {"loss": 2.0} + + Raises: + ValueError: on mismatched reduction types for the same metric key. 
+ """ + if not states: + return {} + + # Collect unique keys across all + all_keys = set(k for state in states for k in state) + + reduced_metrics = {} + for key in all_keys: + metric_states = [state.get(key) for state in states if key in state] + if not metric_states: + continue + + first_reduction_type = metric_states[0]["reduction_type"] + + # Check consistency + for state in metric_states: + if state["reduction_type"] != first_reduction_type: + raise ValueError( + f"Mismatched reduction types for key '{key}': {first_reduction_type} vs {state['reduction_type']}" + ) + + metric_accumulator = ReductionType(first_reduction_type).accumulator_class + reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) + reduced_metrics[key] = reduced_value + + return reduced_metrics + + +class LoggerBackend(ABC): + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc. + + #TODO: improve docstrings. Say how they are used/when/why/what they should do. Keep it short + but informative. For example, it should behave differently if logging per rank or reducing. + how global actor can call get_metadata_for_secondary_ranks from the primary run so it can share with the others + during initialize. + """ + + def __init__(self, logger_backend_config: Dict[str, Any]): + self.logger_backend_config = logger_backend_config + + @abstractmethod + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initializes backend for role in distributed logging flow. + + Called by GlobalLoggingActor: globals first, then broadcasts metadata to locals via fetchers. + + Args: + role (str): "global" (controller/primary) or "local" (per-rank/secondary). + primary_metadata (Optional[Dict[str, Any]]): From global backend for + backend that required shared info, e.g. {"shared_run_id": "abc123"}. + + Raises: ValueError if missing metadata for shared local init. + """ + if primary_logger_backend_metadata is None: + primary_logger_backend_metadata = {} + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + async def finish(self) -> None: + pass + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + return None + + +class ConsoleBackend(LoggerBackend): + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + prefix = ( + get_actor_name_with_rank() + if self.logger_backend_config.get("log_per_rank", True) + else "GLOBAL" + ) + logger.info(f"=== {prefix} METRICS STEP {step} ===") + # TODO: make it a proper display. Maybe pprint? 
+ for key, value in metrics.items(): + logger.info(f" {key}: {value}") + logger.info("==============================\n") + + async def finish(self) -> None: + pass + + +class WandbBackend(LoggerBackend): + # TODO: add some doc about the different modes and this reference: docs.wandb.ai/guides/track/log/distributed-training + # we should probably have an assertion that mode is in the lsit of 3 possible options + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + self.project = logger_backend_config["project"] + self.group = logger_backend_config.get("group", "experiment_group") + self.name = None + self.run = None + self.mode = logger_backend_config.get("mode", "wandb_all_log_all") + valid_modes = [ + "wandb_all_log_all", + "wandb_rank_0_log_all", + "wandb_rank_0_reduce_all", + ] + if self.mode not in valid_modes: + raise ValueError( + f"Invalid WandbBackend mode '{self.mode}'. Must be one of {valid_modes}." + ) + + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + import wandb + + if primary_logger_backend_metadata is None: + primary_logger_backend_metadata = {} + self.name = ( + get_actor_name_with_rank() if role == "local" else "global_controller" + ) + + if self.mode == "wandb_all_log_all" and role == "local": + self.run = wandb.init( + project=self.project, group=self.group, name=self.name + ) + elif self.mode == "wandb_rank_0_log_all": + if role == "global": + # Primary + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init( + project=self.project, group=self.group, settings=settings + ) + # TODO: Make metric definitions automatic or configurable via logger_backend config + self.run.define_metric("global_step") + self.run.define_metric("train/loss", step_metric="global_step") + self.run.define_metric("generate/tokens", step_metric="global_step") + elif role == "local": + # Secondary: Use shared_run_id from primary_logger_backend_metadata + shared_id = primary_logger_backend_metadata.get("shared_run_id") + if shared_id is None: + local_rank = current_rank().rank + raise ValueError( + f"Rank {local_rank}: Shared ID required but not provided" + ) + settings = wandb.Settings( + mode="shared", x_primary=False, x_label=self.name + ) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) + elif self.mode == "wandb_rank_0_reduce_all" and role == "global": + self.run = wandb.init(project=self.project, group=self.group) + # self.run.define_metric("global_step") + # self.run.define_metric("train/loss", step_metric="global_step") + # self.run.define_metric("generate/tokens", step_metric="global_step") + else: + logger.debug(f"Skipped init for {self.mode} mode and {role} role") + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.run: + log_data = {**metrics, "global_step": step} + self.run.log(log_data) + logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + else: + logger.debug(f"WandbBackend: No run, skipping log for {self.name}") + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + if self.run and self.mode == "wandb_rank_0_log_all": + return {"shared_run_id": self.run.id} + return None # {} for others + + async def finish(self) -> None: + if self.run: + self.run.finish() + logger.info(f"WandbBackend {self.name}: Finished run") + + +def get_logger_backend_class(cls_name: str) -> 
type[LoggerBackend]: + """Simple mapping between logger_backend type and its class + + Factory for backend classes from config; returns uninitialized class for role-based init. + """ + if cls_name == "console": + return ConsoleBackend + elif cls_name == "wandb": + return WandbBackend + else: + raise ValueError(f"Unknown logger backend type: {cls_name}") + + +class MetricAccumulator(ABC): + # TODO: add docstring for every method, explaining when/why this is used + def __init__(self, reduction: ReductionType): + self.reduction_type = reduction + + @abstractmethod + def append(self, value: Any) -> None: + """Updates accumulator with new value (e.g., adds to sum and count for MEAN).""" + pass + + @abstractmethod + def get_value(self) -> Any: + """Returns locally reduced value (e.g., sum/count for MEAN).""" + pass + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + """Returns serializable state for cross-rank merge (e.g., {'sum': 10.0, 'count': 5}).""" + pass + + @classmethod + @abstractmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> Any: + """Merges states from multiple ranks into single reduced value (e.g., total_sum/total_count for MEAN).""" + pass + + @abstractmethod + def reset(self) -> None: + """Clears for next accumulation cycle (e.g., sum=0, count=0 for MEAN).""" + pass + + +class MeanAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.count += 1 + + def get_value(self) -> float: + return self.sum / self.count if self.count > 0 else 0.0 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_count = sum(s["count"] for s in states) + return total_sum / total_count if total_count > 0 else 0.0 + + def reset(self) -> None: + self.sum = 0.0 + self.count = 0 + + +class SumAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.total = 0.0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.total += v + + def get_value(self) -> float: + return self.total + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "total": self.total} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return sum(s["total"] for s in states) + + def reset(self) -> None: + self.total = 0.0 + + +class MaxAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.max_val = float("-inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.max_val = max(self.max_val, v) + + def get_value(self) -> float: + return self.max_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return max(s["max_val"] for s in states) + + def reset(self) -> None: + self.max_val = float("-inf") + + +class MinAccumulator(MetricAccumulator): + def 
__init__(self, reduction: ReductionType): + super().__init__(reduction) + self.min_val = float("inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.min_val = min(self.min_val, v) + + def get_value(self) -> float: + return self.min_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return min(s["min_val"] for s in states) + + def reset(self) -> None: + self.min_val = float("inf") + + +class CountAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.count = 0 + + def append(self, value: Any) -> None: + self.count += 1 + + def get_value(self) -> int: + return self.count + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "count": self.count} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> int: + if not states: + return 0 + return sum(s["count"] for s in states) + + def reset(self) -> None: + self.count = 0 + + +class MetricCollector: + """ + Per-rank singleton for accumulating, retrieving, and flushing metrics to backends. + + - Ensures one instance per process (rank-enforced); actors call record_metric(), which delegates here. + - Initialized via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector. + - GlobalLoggingActor flushes trigger reduction and logging to any locally set up backend, and can + optionally also return non-reduced states for global aggregation. + - Resets accumulators post-flush to avoid leaks across train steps. + """ + + _instances: Dict[int, "MetricCollector"] = {} + + def __new__(cls): + rank = current_rank().rank + if rank not in cls._instances: + inst = super().__new__(cls) + cls._instances[rank] = inst + inst._singleton_rank = rank + else: + inst = cls._instances[rank] + if inst._singleton_rank != rank: + raise ValueError( + f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" + ) + return inst + + def __init__(self): + if hasattr(self, "_initialized_sync"): + return + self._initialized_sync = True + self.accumulators: Dict[str, MetricAccumulator] = {} + self.logger_backends: List[LoggerBackend] = [] + self._initialized_async = False + self.rank = current_rank().rank + + async def init_local_backends( + self, metadata_per_primary_backend: Dict[str, Dict[str, Any]] + ) -> None: + """Initializes collector with logger_backends from global config. + + Queries global logger for config; sets up local logger_backends only if log_per_rank=True. + Called once per-rank by LocalFetcherActor. + """ + if self._initialized_async: + return + + from monarch.actor import get_or_spawn_controller + + from forge.observability.metric_actors import GlobalLoggingActor + + global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + + config = await global_logger.get_config.call_one() + if config is None: + raise ValueError(f"Rank {self.rank}: Config not set—call init_config first") + + # Init local logger_backends only if log_per_rank=True, inject state from + # the primary logger, which may have shared info for all secondary local loggers.
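+ # For example, with a shared WandB run the injected mapping has the shape + # {"wandb": {"shared_run_id": "abc123"}}, as produced by get_metadata_for_secondary_ranks().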
+ for logger_backend_config in config["backends"]: + if not logger_backend_config.get("log_per_rank", True): + continue # Skip globals/reduce + cls_name = logger_backend_config["class"] + primary_state = metadata_per_primary_backend.get(cls_name, {}) + logger_backend = get_logger_backend_class(cls_name)(logger_backend_config) + await logger_backend.init( + role="local", primary_logger_backend_metadata=primary_state + ) + self.logger_backends.append(logger_backend) + + self._initialized_async = True + + async def push( + self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN + ) -> None: + if not self._initialized_async: + raise ValueError("Collector not initializedโ€”call init first") + + if key not in self.accumulators: + self.accumulators[key] = reduction.accumulator_class(reduction) + + self.accumulators[key].append(value) + + async def log_and_reset( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): train step used by backends to align all metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. + """ + if not self._initialized_async: + raise ValueError("Collector not initializedโ€”call init first") + + if not self.accumulators: + logger.debug( + f"Collector rank {self.rank}: No metrics to flush for step {step}" + ) + return {} + + # Snapshot states and reset immediately + states = {} + for key, acc in self.accumulators.items(): + states[key] = acc.get_state() + acc.reset() + + # Reduce metrics from states for logging if any per-rank backend + if self.logger_backends: + metrics = {} + for key, state in states.items(): + acc_class = ReductionType(state["reduction_type"]).accumulator_class + metrics[key] = acc_class.get_reduced_value_from_states([state]) + + # Log to local logger_backends + for logger_backend in self.logger_backends: + await logger_backend.log(metrics, step) + + return states if return_state else {} + + async def shutdown(self): + """Shutdown logger_backends if initialized.""" + if not self._initialized_async: + logger.debug( + f"Collector rank {self.rank}: Not initialized, skipping shutdown" + ) + return + + for logger_backend in self.logger_backends: + await logger_backend.finish() From c3370830eca55f00c21f1be147fa473764129817 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 22 Sep 2025 22:42:25 -0700 Subject: [PATCH 05/25] docstring --- apps/toy_metrics/main.py | 2 +- src/forge/observability/metric_actors.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 586d6780d..6165967d0 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -84,7 +84,7 @@ async def main(mode: str = "wandb_all_log_all"): # Now init config on global (inits backends eagerly across fetchers) global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor) - await global_logger.init_config.call_one(logging_config) + await global_logger.initialize_backends.call_one(logging_config) for i in range(3): print(f"\n=== Global Step {i} ===") diff --git a/src/forge/observability/metric_actors.py 
b/src/forge/observability/metric_actors.py index 92a584a23..83a0615ee 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -15,9 +15,11 @@ class LocalFetcherActor(Actor): - """ - Thin per-process actor to trigger MetricCollector ops without direct access, - used by GlobalLoggingActor to broadcast inits/flushes across ranks. + """Thin per-process actor used to trigger MetricCollector singleton + operations without direct access. It is what GlobalLoggingActor + uses to broadcast inits/flushes across ranks. + + GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ @endpoint @@ -58,6 +60,11 @@ async def shutdown(self): class GlobalLoggingActor(Actor): + """Coordinates metric logging across all ranks for every training step. + + Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), + and per-rank and global reduction logging modes.""" + def __init__(self): self.fetchers: Dict[str, LocalFetcherActor] = {} self.config: Dict[str, Any] | None = None @@ -65,7 +72,7 @@ def __init__(self): self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} @endpoint - async def init_config(self, config: Dict[str, Any]): + async def initialize_backends(self, config: Dict[str, Any]): """ Sets config on global actor and inits backends; broadcasts to registered per-rank fetchers. @@ -116,7 +123,7 @@ async def init_config(self, config: Dict[str, Any]): def get_config(self) -> Dict[str, Any] | None: """ Returns the stored logging config for MetricCollector to query during init, - so they can be initialized in their own process. + so local backends can be initialized in their own process. """ return self.config From 16f226776470d2cbd5876765a4f3354207c83a88 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 22 Sep 2025 22:47:25 -0700 Subject: [PATCH 06/25] comments --- src/forge/observability/metrics.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index e6632db90..3e9724bc0 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -85,7 +85,6 @@ def get_actor_name_with_rank() -> str: return rank_name -# Simple push async def record_metric( key: str, value: Any, reduction: ReductionType = ReductionType.MEAN ) -> None: @@ -222,7 +221,8 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None: else "GLOBAL" ) logger.info(f"=== {prefix} METRICS STEP {step} ===") - # TODO: make it a proper display. Maybe pprint? + + # TODO: Improve display. Maybe pprint? 
Currently requires loglevel == info for key, value in metrics.items(): logger.info(f" {key}: {value}") logger.info("==============================\n") @@ -232,8 +232,17 @@ async def finish(self) -> None: class WandbBackend(LoggerBackend): - # TODO: add some doc about the different modes and this reference: docs.wandb.ai/guides/track/log/distributed-training - # we should probably have an assertion that mode is in the lsit of 3 possible options + """Reference: docs.wandb.ai/guides/track/log/distributed-training + + #TODO: give this better names + #TODO: most likely delete wandb_rank_0_log_all + valid_modes = [ + "wandb_all_log_all", # Track multiple processes + "wandb_rank_0_log_all", #Track all processes to a single run + "wandb_rank_0_reduce_all", # Track a single process + ] + """ + def __init__(self, logger_backend_config: Dict[str, Any]): super().__init__(logger_backend_config) self.project = logger_backend_config["project"] From 40e16c28b04f07463c4bbcc0516712361a60533b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 05:50:15 -0700 Subject: [PATCH 07/25] update method name --- apps/toy_metrics/main.py | 2 +- src/forge/observability/metric_actors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 6165967d0..47f3cdee6 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -91,7 +91,7 @@ async def main(mode: str = "wandb_all_log_all"): await trainer.train_step.call(i) for sub in range(3): await generator.generate_step.call(i, sub) - await global_logger.flush_global.call_one(i) + await global_logger.flush.call_one(i) await global_logger.shutdown.call_one() diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 83a0615ee..ace1a3a83 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -132,7 +132,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): self.fetchers[name] = fetcher @endpoint - async def flush_global(self, step: int): + async def flush(self, step: int): """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. From d7c175d64c0418c8cdad2d4a3b266766b69c98a5 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 06:00:37 -0700 Subject: [PATCH 08/25] no circular import --- src/forge/controller/__init__.py | 9 +++------ src/forge/controller/service/replica.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index 23a3d6804..8f7c2f420 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
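+# Downstream modules can now write "from forge.controller import ForgeActor"
+# instead of importing forge.controller.actor directly (see the replica.py hunk
+# below), which is the route this commit takes to remove the circular import.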
+from .actor import ForgeActor from .proc_mesh import get_proc_mesh, stop_proc_mesh @@ -10,7 +11,7 @@ # service async def spawn_actors( name: str, - actor_cls, + actor_cls: ForgeActor, cfg, processes, set_address: bool = False, @@ -22,8 +23,4 @@ async def spawn_actors( return actors -__all__ = [ - "spawn_actors", - "stop_proc_mesh", - "get_proc_mesh", -] +__all__ = ["spawn_actors", "stop_proc_mesh", "get_proc_mesh", "ForgeActor"] diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 569f0e829..b84e5eec7 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -15,7 +15,7 @@ from monarch.actor import ActorError -from forge.controller.actor import ForgeActor +from forge.controller import ForgeActor from forge.types import ProcessConfig logger = logging.getLogger(__name__) From 538e8f24fc0232e293898b051f0835aa8e97054b Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 06:19:22 -0700 Subject: [PATCH 09/25] update command --- apps/grpo/qwen3_1_7b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 1f90a091c..7b1804850 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -1,5 +1,5 @@ # Grouped Relative Policy Optimization (GRPO) -# >>> python -m apps.grpo.qwen3_1_7b --config apps/grpo/qwen3_1_7b.yaml +# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml # Global configuration group_size: 8 From 166b5d48745df1ac3dd4098ad4f2f4653be8fa0a Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 08:02:37 -0700 Subject: [PATCH 10/25] update arg name --- apps/toy_metrics/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 47f3cdee6..16dfe12a0 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -76,7 +76,7 @@ async def main(mode: str = "wandb_all_log_all"): logging_config = { "backends": backends, } - service_config = {"procs_per_replica": 2, "num_replicas": 2, "with_gpus": False} + service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} # Spawn services first (triggers registrations via provisioner hook) trainer = await TrainActor.options(**service_config).as_service() From e27d4511f15aee74af0574e9b847533379e614a6 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 09:17:46 -0700 Subject: [PATCH 11/25] move metric actor out of asyncio lock --- src/forge/controller/provisioner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 47f2dd34b..a50e0cbb5 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -232,14 +232,14 @@ def bootstrap(gpu_ids: int): procs._host = host_mesh - # Spawn local logging actor on each process and register with global logger - await self._setup_logging(procs) - # If we created a server, track so we can tear it down later. 
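+        # The per-process logging actor is spawned after this bookkeeping; see
+        # the _setup_logging call added below.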
if server_name: self._server_names.append(server_name) self._proc_server_map[procs] = server_name + # Spawn local logging actor on each process and register with global logger + await self._setup_logging(procs) + return procs async def stop_proc_mesh(self, proc_mesh: ProcMesh): From 5cadbee31560c467b8a654142da6a90101826213 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 23 Sep 2025 16:40:59 -0700 Subject: [PATCH 12/25] fix deregister --- apps/toy_metrics/main.py | 3 ++- src/forge/controller/provisioner.py | 6 +++--- src/forge/observability/metric_actors.py | 7 +++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 16dfe12a0..7ce75610a 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -11,6 +11,7 @@ import time from forge.controller.actor import ForgeActor +from forge.controller.provisioner import shutdown from forge.observability.metric_actors import GlobalLoggingActor from forge.observability.metrics import record_metric, ReductionType @@ -93,7 +94,7 @@ async def main(mode: str = "wandb_all_log_all"): await generator.generate_step.call(i, sub) await global_logger.flush.call_one(i) - await global_logger.shutdown.call_one() + await shutdown() if __name__ == "__main__": diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 4be9d7684..f7172fade 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -14,6 +14,8 @@ import uuid import monarch + +from forge.types import ProcessConfig from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import ( @@ -28,8 +30,6 @@ from monarch.tools.components import hyperactor from monarch.tools.config import Config -from forge.types import ProcessConfig - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -256,7 +256,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): "global_logger", GlobalLoggingActor ) process_name = f"proc_mesh_{id(proc_mesh)}" - await global_logger.deregister.call_one(process_name) + await global_logger.deregister_fetcher.call_one(process_name) if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index ace1a3a83..96e3dad03 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -129,8 +129,15 @@ def get_config(self) -> Dict[str, Any] | None: @endpoint async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): + """Registers a fetcher with the global actor. Each key represents a process mesh. + If there are 2 processes, each with 2 replicas with N gpus, we would + have 4 keys, i.e. 
2 process meshes, each with 2 replicas."""
         self.fetchers[name] = fetcher

+    @endpoint
+    async def deregister_fetcher(self, name: str):
+        del self.fetchers[name]
+
     @endpoint
     async def flush(self, step: int):
         """

From cb33d5f30f5bd15abd69dfb7138c37831d3f967c Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 23 Sep 2025 16:47:07 -0700
Subject: [PATCH 13/25] lint

---
 src/forge/controller/provisioner.py | 4 ++--
 src/forge/data_models/episode.py    | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index f7172fade..0bb3e480b 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -14,8 +14,6 @@
 import uuid

 import monarch
-
-from forge.types import ProcessConfig
 from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
 from monarch._src.actor.shape import NDSlice, Shape
 from monarch.actor import (
@@ -30,6 +28,8 @@
 from monarch.tools.components import hyperactor
 from monarch.tools.config import Config

+from forge.types import ProcessConfig
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

diff --git a/src/forge/data_models/episode.py b/src/forge/data_models/episode.py
index 5df2352ab..835373d18 100644
--- a/src/forge/data_models/episode.py
+++ b/src/forge/data_models/episode.py
@@ -8,6 +8,7 @@
 from typing import Optional, Sequence

 import torch
+
 from forge.data_models.scored_completion import ScoredCompletion


From f28097cfa715be70f4351022106536fc5daef333 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Tue, 23 Sep 2025 16:56:26 -0700
Subject: [PATCH 14/25] docstring

---
 src/forge/controller/provisioner.py       |   6 +-
 src/forge/observability/metric_actors.py  |   2 +-
 src/forge/observability/metrics.py        | 383 ++++++++++++-----------
 3 files changed, 200 insertions(+), 191 deletions(-)

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 0bb3e480b..cd99cb9b4 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -150,7 +150,11 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
         )

     async def _setup_logging(self, procs: ProcMesh) -> None:
-        """Spawn and register local fetcher for metrics on each process."""
+        """Spawn and register local fetcher for metric logging on each process.
+        When a service is spawned, we create for each rank a LocalFetcherActor and
+        store it in the GlobalLoggingActor. Backends (e.g. wandb) should be eagerly instantiated
+        later in main by calling `global_logger.initialize_backends.call_one(logging_config)`
+        """
         from forge.observability.metric_actors import (
             GlobalLoggingActor,
             LocalFetcherActor,
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 96e3dad03..18916fe24 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -30,7 +30,7 @@ async def log_and_reset(

         Args:
             step (int): train step used by backends to align all metrics on the same x-axis
-            return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
+            return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
                 If False, returns empty dict, else returns the state of all metrics collected.
Returns: Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3e9724bc0..98cd10a63 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -158,195 +158,6 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, return reduced_metrics -class LoggerBackend(ABC): - """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc. - - #TODO: improve docstrings. Say how they are used/when/why/what they should do. Keep it short - but informative. For example, it should behave differently if logging per rank or reducing. - how global actor can call get_metadata_for_secondary_ranks from the primary run so it can share with the others - during initialize. - """ - - def __init__(self, logger_backend_config: Dict[str, Any]): - self.logger_backend_config = logger_backend_config - - @abstractmethod - async def init( - self, - role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Initializes backend for role in distributed logging flow. - - Called by GlobalLoggingActor: globals first, then broadcasts metadata to locals via fetchers. - - Args: - role (str): "global" (controller/primary) or "local" (per-rank/secondary). - primary_metadata (Optional[Dict[str, Any]]): From global backend for - backend that required shared info, e.g. {"shared_run_id": "abc123"}. - - Raises: ValueError if missing metadata for shared local init. - """ - if primary_logger_backend_metadata is None: - primary_logger_backend_metadata = {} - pass - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - pass - - async def finish(self) -> None: - pass - - def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: - """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" - return None - - -class ConsoleBackend(LoggerBackend): - def __init__(self, logger_backend_config: Dict[str, Any]): - super().__init__(logger_backend_config) - - async def init( - self, - role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - pass - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("log_per_rank", True) - else "GLOBAL" - ) - logger.info(f"=== {prefix} METRICS STEP {step} ===") - - # TODO: Improve display. Maybe pprint? 
Currently requires loglevel == info - for key, value in metrics.items(): - logger.info(f" {key}: {value}") - logger.info("==============================\n") - - async def finish(self) -> None: - pass - - -class WandbBackend(LoggerBackend): - """Reference: docs.wandb.ai/guides/track/log/distributed-training - - #TODO: give this better names - #TODO: most likely delete wandb_rank_0_log_all - valid_modes = [ - "wandb_all_log_all", # Track multiple processes - "wandb_rank_0_log_all", #Track all processes to a single run - "wandb_rank_0_reduce_all", # Track a single process - ] - """ - - def __init__(self, logger_backend_config: Dict[str, Any]): - super().__init__(logger_backend_config) - self.project = logger_backend_config["project"] - self.group = logger_backend_config.get("group", "experiment_group") - self.name = None - self.run = None - self.mode = logger_backend_config.get("mode", "wandb_all_log_all") - valid_modes = [ - "wandb_all_log_all", - "wandb_rank_0_log_all", - "wandb_rank_0_reduce_all", - ] - if self.mode not in valid_modes: - raise ValueError( - f"Invalid WandbBackend mode '{self.mode}'. Must be one of {valid_modes}." - ) - - async def init( - self, - role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, - ) -> None: - import wandb - - if primary_logger_backend_metadata is None: - primary_logger_backend_metadata = {} - self.name = ( - get_actor_name_with_rank() if role == "local" else "global_controller" - ) - - if self.mode == "wandb_all_log_all" and role == "local": - self.run = wandb.init( - project=self.project, group=self.group, name=self.name - ) - elif self.mode == "wandb_rank_0_log_all": - if role == "global": - # Primary - settings = wandb.Settings( - mode="shared", x_primary=True, x_label="controller_primary" - ) - self.run = wandb.init( - project=self.project, group=self.group, settings=settings - ) - # TODO: Make metric definitions automatic or configurable via logger_backend config - self.run.define_metric("global_step") - self.run.define_metric("train/loss", step_metric="global_step") - self.run.define_metric("generate/tokens", step_metric="global_step") - elif role == "local": - # Secondary: Use shared_run_id from primary_logger_backend_metadata - shared_id = primary_logger_backend_metadata.get("shared_run_id") - if shared_id is None: - local_rank = current_rank().rank - raise ValueError( - f"Rank {local_rank}: Shared ID required but not provided" - ) - settings = wandb.Settings( - mode="shared", x_primary=False, x_label=self.name - ) - self.run = wandb.init( - id=shared_id, - project=self.project, - group=self.group, - settings=settings, - ) - elif self.mode == "wandb_rank_0_reduce_all" and role == "global": - self.run = wandb.init(project=self.project, group=self.group) - # self.run.define_metric("global_step") - # self.run.define_metric("train/loss", step_metric="global_step") - # self.run.define_metric("generate/tokens", step_metric="global_step") - else: - logger.debug(f"Skipped init for {self.mode} mode and {role} role") - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - if self.run: - log_data = {**metrics, "global_step": step} - self.run.log(log_data) - logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") - else: - logger.debug(f"WandbBackend: No run, skipping log for {self.name}") - - def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: - if self.run and self.mode == "wandb_rank_0_log_all": - return {"shared_run_id": self.run.id} - return None # {} for others - - async 
def finish(self) -> None: - if self.run: - self.run.finish() - logger.info(f"WandbBackend {self.name}: Finished run") - - -def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: - """Simple mapping between logger_backend type and its class - - Factory for backend classes from config; returns uninitialized class for role-based init. - """ - if cls_name == "console": - return ConsoleBackend - elif cls_name == "wandb": - return WandbBackend - else: - raise ValueError(f"Unknown logger backend type: {cls_name}") - - class MetricAccumulator(ABC): # TODO: add docstring for every method, explaining when/why this is used def __init__(self, reduction: ReductionType): @@ -640,3 +451,197 @@ async def shutdown(self): for logger_backend in self.logger_backends: await logger_backend.finish() + + +########### +# Backends # +########### + + +class LoggerBackend(ABC): + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc. + + #TODO: improve docstrings. Say how they are used/when/why/what they should do. Keep it short + but informative. For example, it should behave differently if logging per rank or reducing. + how global actor can call get_metadata_for_secondary_ranks from the primary run so it can share with the others + during initialize. + """ + + def __init__(self, logger_backend_config: Dict[str, Any]): + self.logger_backend_config = logger_backend_config + + @abstractmethod + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initializes backend for role in distributed logging flow. + + Called by GlobalLoggingActor: globals first, then broadcasts metadata to locals via fetchers. + + Args: + role (str): "global" (controller/primary) or "local" (per-rank/secondary). + primary_metadata (Optional[Dict[str, Any]]): From global backend for + backend that required shared info, e.g. {"shared_run_id": "abc123"}. + + Raises: ValueError if missing metadata for shared local init. + """ + if primary_logger_backend_metadata is None: + primary_logger_backend_metadata = {} + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + async def finish(self) -> None: + pass + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + return None + + +class ConsoleBackend(LoggerBackend): + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + prefix = ( + get_actor_name_with_rank() + if self.logger_backend_config.get("log_per_rank", True) + else "GLOBAL" + ) + logger.info(f"=== {prefix} METRICS STEP {step} ===") + + # TODO: Improve display. Maybe pprint? 
Currently requires loglevel == info + for key, value in metrics.items(): + logger.info(f" {key}: {value}") + logger.info("==============================\n") + + async def finish(self) -> None: + pass + + +class WandbBackend(LoggerBackend): + """Reference: docs.wandb.ai/guides/track/log/distributed-training + + #TODO: give this better names + #TODO: most likely delete wandb_rank_0_log_all + valid_modes = [ + "wandb_all_log_all", # Track multiple processes + "wandb_rank_0_log_all", #Track all processes to a single run + "wandb_rank_0_reduce_all", # Track a single process + ] + """ + + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + self.project = logger_backend_config["project"] + self.group = logger_backend_config.get("group", "experiment_group") + self.name = None + self.run = None + self.mode = logger_backend_config.get("mode", "wandb_all_log_all") + valid_modes = [ + "wandb_all_log_all", + "wandb_rank_0_log_all", + "wandb_rank_0_reduce_all", + ] + if self.mode not in valid_modes: + raise ValueError( + f"Invalid WandbBackend mode '{self.mode}'. Must be one of {valid_modes}." + ) + + async def init( + self, + role: str, + primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + import wandb + + if primary_logger_backend_metadata is None: + primary_logger_backend_metadata = {} + self.name = ( + get_actor_name_with_rank() if role == "local" else "global_controller" + ) + + if self.mode == "wandb_all_log_all" and role == "local": + self.run = wandb.init( + project=self.project, group=self.group, name=self.name + ) + elif self.mode == "wandb_rank_0_log_all": + if role == "global": + # Primary + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init( + project=self.project, group=self.group, settings=settings + ) + # TODO: Make metric definitions automatic or configurable via logger_backend config + self.run.define_metric("global_step") + self.run.define_metric("train/loss", step_metric="global_step") + self.run.define_metric("generate/tokens", step_metric="global_step") + elif role == "local": + # Secondary: Use shared_run_id from primary_logger_backend_metadata + shared_id = primary_logger_backend_metadata.get("shared_run_id") + if shared_id is None: + local_rank = current_rank().rank + raise ValueError( + f"Rank {local_rank}: Shared ID required but not provided" + ) + settings = wandb.Settings( + mode="shared", x_primary=False, x_label=self.name + ) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) + elif self.mode == "wandb_rank_0_reduce_all" and role == "global": + self.run = wandb.init(project=self.project, group=self.group) + # self.run.define_metric("global_step") + # self.run.define_metric("train/loss", step_metric="global_step") + # self.run.define_metric("generate/tokens", step_metric="global_step") + else: + logger.debug(f"Skipped init for {self.mode} mode and {role} role") + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.run: + log_data = {**metrics, "global_step": step} + self.run.log(log_data) + logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + else: + logger.debug(f"WandbBackend: No run, skipping log for {self.name}") + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + if self.run and self.mode == "wandb_rank_0_log_all": + return {"shared_run_id": self.run.id} + return None # {} for others + + async 
def finish(self) -> None: + if self.run: + self.run.finish() + logger.info(f"WandbBackend {self.name}: Finished run") + + +def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: + """Simple mapping between logger_backend type and its class + + Factory for backend classes from config; returns uninitialized class for role-based init. + """ + if cls_name == "console": + return ConsoleBackend + elif cls_name == "wandb": + return WandbBackend + else: + raise ValueError(f"Unknown logger backend type: {cls_name}") From 06afbb5b2ad259937261bd409ab84083fbdad942 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 24 Sep 2025 08:29:09 -0700 Subject: [PATCH 15/25] fix result extraction and add logger shutdown --- apps/toy_metrics/main.py | 8 ++++++++ src/forge/observability/metric_actors.py | 23 ++++++++++++----------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 7ce75610a..8231914ad 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -94,6 +94,14 @@ async def main(mode: str = "wandb_all_log_all"): await generator.generate_step.call(i, sub) await global_logger.flush.call_one(i) + # shutdown + await asyncio.gather( + trainer.shutdown(), + generator.shutdown(), + ) + + await global_logger.shutdown.call_one() + await shutdown() diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 18916fe24..e91e8847e 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -167,18 +167,19 @@ async def flush(self, step: int): if has_reduce: # Handle exceptions and extract values from ValueMesh results all_local_states = [] - for res in results: - if isinstance(res, Exception): - logger.warning(f"Flush failed on a fetcher: {res}") + for result in results: + if isinstance(result, Exception): + logger.warning(f"Flush failed on a fetcher: {result}") continue - # res is a ValueMesh. TODO: use public API (.items()), but need to parse metadata - res = res._values - if isinstance(res, list): - all_local_states.extend( - [r for r in res if isinstance(r, dict) and r] - ) - elif isinstance(res, dict) and res: - all_local_states.append(res) + + # result is a generator that outputs {{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. 
{gpu_info=}, {local_metric_state=}" + ) if not all_local_states: logger.warning(f"No states to reduce for step {step}") From 53699394dc5ca14763135127c9c3d60f3fa88669 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 24 Sep 2025 09:36:45 -0700 Subject: [PATCH 16/25] fix shutdown order --- apps/toy_metrics/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 8231914ad..46b7ab552 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -95,13 +95,13 @@ async def main(mode: str = "wandb_all_log_all"): await global_logger.flush.call_one(i) # shutdown + await asyncio.gather(global_logger.shutdown.call_one()) + await asyncio.gather( trainer.shutdown(), generator.shutdown(), ) - await global_logger.shutdown.call_one() - await shutdown() From ffe09b9f0659476e8247a3b204f3b152e079332d Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 26 Sep 2025 07:07:34 -0700 Subject: [PATCH 17/25] simplification + docstrings --- apps/toy_metrics/main.py | 78 ++--- src/forge/observability/metric_actors.py | 110 ++++--- src/forge/observability/metrics.py | 364 +++++++++++++---------- 3 files changed, 294 insertions(+), 258 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 46b7ab552..d373de565 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -7,7 +7,6 @@ import asyncio import logging -import sys import time from forge.controller.actor import ForgeActor @@ -21,62 +20,44 @@ class TrainActor(ForgeActor): + """Example training actor that records loss metrics.""" + @endpoint async def train_step(self, step: int): rank = current_rank().rank value = rank * 1000 + 100 * step - print(f"๐Ÿ”ง Train rank {rank}: Step {step}, loss={value}") - await record_metric("train/loss", value) + print(f"[TRAIN] Rank {rank}: Step {step}, loss={value}") + record_metric("train/loss", value) class GeneratorActor(ForgeActor): + """Example generation actor that records token count metrics.""" + @endpoint async def generate_step(self, step: int, substep: int): rank = current_rank().rank value = rank * 1000 + step * 100 + substep * 10 - print(f"๐ŸŽฏ Gen rank {rank}: Step {step}.{substep}, tokens={value}") - await record_metric("generate/tokens", value, ReductionType.SUM) + print(f"[GEN] Rank {rank}: Step {step}.{substep}, tokens={value}") + record_metric("generate/tokens", value, ReductionType.SUM) # Main -async def main(mode: str = "wandb_all_log_all"): - group = f"experiment_group_{int(time.time())}" - if mode == "wandb_all_log_all": - backends = [ - {"class": "console", "log_per_rank": True}, - { - "class": "wandb", - "project": "my_project", - "group": group, - "mode": "wandb_all_log_all", - "log_per_rank": True, - }, - ] - elif mode == "wandb_rank_0_reduce_all": - backends = [ - {"class": "console", "log_per_rank": False}, - { - "class": "wandb", - "project": "my_project", - "group": group, - "mode": "wandb_rank_0_reduce_all", - "log_per_rank": False, - }, - ] - else: # wandb_rank_0_log_all - backends = [ - { - "class": "wandb", - "project": "my_project", - "group": group, - "mode": "wandb_rank_0_log_all", - "log_per_rank": True, - }, - ] - - logging_config = { - "backends": backends, +async def main(): + """Example demonstrating distributed metric logging with different backends.""" + group = f"grpo_exp_{int(time.time())}" + + # Config format: {backend_name: backend_config_dict} + # Each backend can specify log_per_rank to control distributed logging behavior + config = { + 
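+        # Each entry below configures one backend; whether a backend logs per rank
+        # or a single reduced value is decided per backend at init time.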
"console": {"log_per_rank": False}, + "wandb": { + "project": "my_project", + "group": group, + "mode": "wandb_rank_0_reduce_all", + "log_per_rank": False, + }, } + service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} # Spawn services first (triggers registrations via provisioner hook) @@ -85,7 +66,7 @@ async def main(mode: str = "wandb_all_log_all"): # Now init config on global (inits backends eagerly across fetchers) global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor) - await global_logger.initialize_backends.call_one(logging_config) + await global_logger.init_backends.call_one(config) for i in range(3): print(f"\n=== Global Step {i} ===") @@ -106,13 +87,4 @@ async def main(mode: str = "wandb_all_log_all"): if __name__ == "__main__": - mode = sys.argv[1] if len(sys.argv) > 1 else "wandb_all_log_all" - valid_modes = [ - "wandb_all_log_all", - "wandb_rank_0_log_all", - "wandb_rank_0_reduce_all", - ] - if mode not in valid_modes: - print(f"Invalid mode: {mode}. Use {valid_modes}") - sys.exit(1) - asyncio.run(main(mode)) + asyncio.run(main()) diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e91e8847e..9224cd007 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -6,7 +6,7 @@ import asyncio import logging -from typing import Any, Dict +from typing import Any, Dict, Optional from monarch.actor import Actor, endpoint @@ -22,8 +22,11 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ + def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + self.global_logger = global_logger + @endpoint - async def log_and_reset( + async def flush( self, step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. @@ -39,17 +42,20 @@ async def log_and_reset( from forge.observability.metrics import MetricCollector collector = MetricCollector() - result = await collector.log_and_reset(step, return_state=return_state) + result = await collector.flush(step, return_state=return_state) return result @endpoint - async def init_collector( - self, metadata_per_primary_backend: Dict[str, Dict[str, Any]] + async def init_backends( + self, + metadata_per_primary_backend: Dict[str, Dict[str, Any]], + config: Dict[str, Any], ): + """Init local (per-rank) logger backends and MetricCollector.""" from forge.observability.metrics import MetricCollector collector = MetricCollector() - await collector.init_local_backends(metadata_per_primary_backend) + await collector.init_backends(metadata_per_primary_backend, config) @endpoint async def shutdown(self): @@ -63,7 +69,20 @@ class GlobalLoggingActor(Actor): """Coordinates metric logging across all ranks for every training step. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), - and per-rank and global reduction logging modes.""" + for per-rank and/or global reduction logging modes. + + If a backend config has flag `reduce_across_ranks=False`, an instance of the backend + is initialized per-rank, otherwise it is done once globally. + + This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor + is automatically spawned per-rank in `forge.controller.provisioner.py` and registered + with this actor. 
The LocalFetcherActor is responsible for instantiating + the per-rank MetricCollector. + + In summary, the flow is: + - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector + - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + """ def __init__(self): self.fetchers: Dict[str, LocalFetcherActor] = {} @@ -72,61 +91,58 @@ def __init__(self): self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} @endpoint - async def initialize_backends(self, config: Dict[str, Any]): + async def init_backends(self, config: Dict[str, Any]): """ - Sets config on global actor and inits backends; broadcasts to registered per-rank fetchers. + Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + in all registered fetchers. + + A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source + for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. - - Validates unique backend classes; - - Extracts metadata from a primary logger to be shared with secondary loggers (e.g., shared run IDs) for per-rank modes. - - Eagerly inits metric collectors on fetchers. + The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, + a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, + and will log independently, i.e. each rank will have its own run in wandb. + + Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states + and reduce them to a single value, which will be logged by the primary backend in this controller. Args: - config (Dict[str, Any]): Config for metric logging + config (Dict[str, Any]): Config for metric logging where keys are backend names, + e.g. 
{"console": {"log_per_rank": True}, "wandb": {"log_per_rank": False}} """ - self.config = config - - # Validate unique classes - classes = [b["class"] for b in config["backends"]] - if len(set(classes)) != len(classes): - raise ValueError("Duplicate logger_backend classes in config") # Init global logger_backends and states where needed from forge.observability.metrics import get_logger_backend_class - for backend_config in config["backends"]: - cls_name = backend_config["class"] - backend = get_logger_backend_class(cls_name)(backend_config) + for backend_name, backend_config in config.items(): + backend = get_logger_backend_class(backend_name)(backend_config) await backend.init(role="global") # Extract metadata from primary logger to be shared with secondary loggers # and store it - log_per_rank = backend_config.get("log_per_rank", True) - if log_per_rank: + reduce_across_ranks = backend_config.get("reduce_across_ranks", True) + if not reduce_across_ranks: primary_backend_metadata = ( backend.get_metadata_for_secondary_ranks() or {} ) - self.metadata_per_primary_backend[cls_name] = primary_backend_metadata + self.metadata_per_primary_backend[ + backend_name + ] = primary_backend_metadata # Store global logger backends - if not log_per_rank: - self.global_logger_backends[cls_name] = backend + if reduce_across_ranks: + self.global_logger_backends[backend_name] = backend - # Eager init collectors on all registered fetchers in parallel, passing primary states + # Eager init collectors on all registered fetchers in parallel, passing primary states and config if self.fetchers: tasks = [ - fetcher.init_collector.call(self.metadata_per_primary_backend) + fetcher.init_backends.call( + self.metadata_per_primary_backend, self.config + ) for fetcher in self.fetchers.values() ] await asyncio.gather(*tasks, return_exceptions=True) - @endpoint - def get_config(self) -> Dict[str, Any] | None: - """ - Returns the stored logging config for MetricCollector to query during init, - so local backends can be initialized in their own process. - """ - return self.config - @endpoint async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): """Registers a fetcher with the global actor. Each key represents a process mesh. @@ -136,6 +152,11 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): @endpoint async def deregister_fetcher(self, name: str): + if name not in self.fetchers: + logger.warning( + f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." 
+ ) + return del self.fetchers[name] @endpoint @@ -151,20 +172,25 @@ async def flush(self, step: int): return config = self.config - has_reduce = any(not b.get("log_per_rank", True) for b in config["backends"]) + # if reduce_across_ranks=True, we need to reduce the states from all ranks + # and log with the primary backend + requires_reduce = any( + backend_config.get("reduce_across_ranks", True) + for backend_config in config.values() + ) logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") - # Broadcast log_and_reset to all fetchers + # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.log_and_reset.call(step, return_state=has_reduce) + f.flush.call(step, return_state=requires_reduce) for f in self.fetchers.values() ], return_exceptions=True, ) - if has_reduce: + if requires_reduce: # Handle exceptions and extract values from ValueMesh results all_local_states = [] for result in results: @@ -172,7 +198,7 @@ async def flush(self, step: int): logger.warning(f"Flush failed on a fetcher: {result}") continue - # result is a generator that outputs {{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] for gpu_info, local_metric_state in result.items(): if isinstance(local_metric_state, dict): all_local_states.append(local_metric_state) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 98cd10a63..cfe8f26a3 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -15,13 +15,12 @@ logger = logging.getLogger(__name__) -# Reduction Types class ReductionType(Enum): MEAN = "mean" SUM = "sum" MAX = "max" MIN = "min" - COUNT = "count" + STD = "std" @property def accumulator_class(self): @@ -30,7 +29,7 @@ def accumulator_class(self): ReductionType.SUM: SumAccumulator, ReductionType.MAX: MaxAccumulator, ReductionType.MIN: MinAccumulator, - ReductionType.COUNT: CountAccumulator, + ReductionType.STD: StdAccumulator, } return mapping[self] @@ -39,10 +38,9 @@ def get_actor_name_with_rank() -> str: """ Extracts actor information from Monarch context to form a logging name. - Returns a string like "TrainActor_abcd_r0" (actor type + replica ID suffix + local rank). - Relies on parsing actor_id string; fallback to "UnknownActor" if context unavailable. - - # TODO: Replace string parsing with structured actor_id access once Monarch exposes it. + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. """ # Add more defensive checks ctx = context() @@ -85,7 +83,7 @@ def get_actor_name_with_rank() -> str: return rank_name -async def record_metric( +def record_metric( key: str, value: Any, reduction: ReductionType = ReductionType.MEAN ) -> None: """ @@ -99,11 +97,11 @@ async def record_metric( `forge.observability.metric_actors.LocalFetcherActor`, instantiated during actor initialization. - Records are flushed after every N train steps, triggered by - `forge.observability.metric_actors.GlobalLoggingActor` + Records are flushed when `forge.observability.metric_actors.GlobalLoggingActor.flush()` + is called, typically triggered by the training loop at regular intervals. 
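+
+    Example (illustrative values; assumes backends were initialized via
+    GlobalLoggingActor.init_backends, as in apps/toy_metrics/main.py):
+        record_metric("train/loss", 0.42)  # ReductionType.MEAN by default
+        record_metric("generate/tokens", 128, ReductionType.SUM)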
""" collector = MetricCollector() - await collector.push(key, value, reduction) + collector.push(key, value, reduction) def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: @@ -158,8 +156,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, return reduced_metrics +################ +# Accumulators # +################ + + class MetricAccumulator(ABC): - # TODO: add docstring for every method, explaining when/why this is used + """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" + def __init__(self, reduction: ReductionType): self.reduction_type = reduction @@ -291,44 +295,29 @@ def reset(self) -> None: self.min_val = float("inf") -class CountAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.count = 0 - - def append(self, value: Any) -> None: - self.count += 1 - - def get_value(self) -> int: - return self.count - - def get_state(self) -> Dict[str, Any]: - return {"reduction_type": self.reduction_type.value, "count": self.count} - - @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> int: - if not states: - return 0 - return sum(s["count"] for s in states) - - def reset(self) -> None: - self.count = 0 +############# +# Collector # +############# class MetricCollector: - """ - Per-rank singleton for accumulating, retrieving or flushing metrics to backends. + """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. + + A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, + in the GlobalLoggingActor. - - Ensures one instance per process (rank-enforced); actors call record_metric() which delegates here. + - Ensures one instance per process; actors call record_metric() which delegates here. - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also - return non-reduced states for global aggregation. + return non-reduced states for global aggregation. This can be different for each backend. - Resets accumulators post-flush to avoid leaks across train steps; """ _instances: Dict[int, "MetricCollector"] = {} def __new__(cls): + """Singleton per-rank, ensures one instance per process.""" rank = current_rank().rank if rank not in cls._instances: inst = super().__new__(cls) @@ -345,68 +334,57 @@ def __new__(cls): def __init__(self): if hasattr(self, "_initialized_sync"): return - self._initialized_sync = True self.accumulators: Dict[str, MetricAccumulator] = {} - self.logger_backends: List[LoggerBackend] = [] self._initialized_async = False self.rank = current_rank().rank + self.logger_backends: List[LoggerBackend] = [] - async def init_local_backends( - self, metadata_per_primary_backend: Dict[str, Dict[str, Any]] + async def init_backends( + self, + metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], + config: Dict[str, Any], ) -> None: - """Initializes collector with logger_backends from global config. + """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated + once globally. 
- Queries global logger for config; sets up local logger_backends only if log_per_rank=True. - Called once per-rank by LocalFetcherActor. + Args: + metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary + logger backend, e.g., {"wandb": {"run_id": "abc123"}}. + config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. """ if self._initialized_async: + logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return - from monarch.actor import get_or_spawn_controller - - from forge.observability.metric_actors import GlobalLoggingActor - - global_logger = await get_or_spawn_controller( - "global_logger", GlobalLoggingActor - ) - - config = await global_logger.get_config.call_one() - if config is None: - raise ValueError(f"Rank {self.rank}: Config not setโ€”call init_config first") - - # Init local logger_backends only if log_per_rank=True, inject state from - # the primary logger, which may have shared info for all secondary local loggers. - for logger_backend_config in config["backends"]: - if not logger_backend_config.get("log_per_rank", True): - continue # Skip globals/reduce - cls_name = logger_backend_config["class"] - primary_state = metadata_per_primary_backend.get(cls_name, {}) - logger_backend = get_logger_backend_class(cls_name)(logger_backend_config) + # instantiate local backends if any + for backend_name, backend_config in config.items(): + if backend_config.get("reduce_across_ranks", True): + continue # Skip local backend instantiation and use global instead + primary_state = metadata_per_primary_backend.get(backend_name, {}) + logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role="local", primary_logger_backend_metadata=primary_state + role="local", primary_logger_metadata=primary_state ) self.logger_backends.append(logger_backend) self._initialized_async = True - async def push( + def push( self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN ) -> None: - if not self._initialized_async: - raise ValueError("Collector not initializedโ€”call init first") - if key not in self.accumulators: self.accumulators[key] = reduction.accumulator_class(reduction) self.accumulators[key].append(value) - async def log_and_reset( + async def flush( self, step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: - step (int): train step used by backends to align all metrics on the same x-axis + step (int): Step used by backends to align metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -414,11 +392,14 @@ async def log_and_reset( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._initialized_async: - raise ValueError("Collector not initializedโ€”call init first") + logger.debug( + f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." 
+ ) + return {} if not self.accumulators: logger.debug( - f"Collector rank {self.rank}: No metrics to flush for step {step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" ) return {} @@ -445,7 +426,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._initialized_async: logger.debug( - f"Collector rank {self.rank}: Not initialized, skipping shutdown" + f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" ) return @@ -459,13 +440,7 @@ async def shutdown(self): class LoggerBackend(ABC): - """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc. - - #TODO: improve docstrings. Say how they are used/when/why/what they should do. Keep it short - but informative. For example, it should behave differently if logging per rank or reducing. - how global actor can call get_metadata_for_secondary_ranks from the primary run so it can share with the others - during initialize. - """ + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" def __init__(self, logger_backend_config: Dict[str, Any]): self.logger_backend_config = logger_backend_config @@ -474,22 +449,21 @@ def __init__(self, logger_backend_config: Dict[str, Any]): async def init( self, role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: """ - Initializes backend for role in distributed logging flow. - - Called by GlobalLoggingActor: globals first, then broadcasts metadata to locals via fetchers. + Initializes backend, e.g. wandb.run.init(). Args: role (str): "global" (controller/primary) or "local" (per-rank/secondary). - primary_metadata (Optional[Dict[str, Any]]): From global backend for + Can be used to behave differently for primary vs secondary roles. + primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. Raises: ValueError if missing metadata for shared local init. """ - if primary_logger_backend_metadata is None: - primary_logger_backend_metadata = {} + if primary_logger_metadata is None: + primary_logger_metadata = {} pass async def log(self, metrics: Dict[str, Any], step: int) -> None: @@ -504,25 +478,24 @@ def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: class ConsoleBackend(LoggerBackend): + """Simple console logging of metrics.""" + def __init__(self, logger_backend_config: Dict[str, Any]): super().__init__(logger_backend_config) async def init( self, role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: - pass - - async def log(self, metrics: Dict[str, Any], step: int) -> None: - prefix = ( + self.prefix = ( get_actor_name_with_rank() if self.logger_backend_config.get("log_per_rank", True) else "GLOBAL" ) - logger.info(f"=== {prefix} METRICS STEP {step} ===") - # TODO: Improve display. Maybe pprint? 
Currently requires loglevel == info + async def log(self, metrics: Dict[str, Any], step: int) -> None: + logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") for key, value in metrics.items(): logger.info(f" {key}: {value}") logger.info("==============================\n") @@ -532,15 +505,21 @@ async def finish(self) -> None: class WandbBackend(LoggerBackend): - """Reference: docs.wandb.ai/guides/track/log/distributed-training - - #TODO: give this better names - #TODO: most likely delete wandb_rank_0_log_all - valid_modes = [ - "wandb_all_log_all", # Track multiple processes - "wandb_rank_0_log_all", #Track all processes to a single run - "wandb_rank_0_reduce_all", # Track a single process - ] + """ + Weights & Biases logging backend for distributed training. + + Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: + Track a single process: reduce_across_ranks=True + Track each process separately: reduce_across_ranks=False, share_run_id=False + Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + + Configuration: + reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). + If False, enables per-rank logging; then use share_run_id to pick mode. + share_run_id (bool, default False): Only used if reduce_across_ranks=False. + True -> shared run across ranks; False -> separate runs per rank. + project (str): WandB project name + group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" """ def __init__(self, logger_backend_config: Dict[str, Any]): @@ -549,71 +528,81 @@ def __init__(self, logger_backend_config: Dict[str, Any]): self.group = logger_backend_config.get("group", "experiment_group") self.name = None self.run = None - self.mode = logger_backend_config.get("mode", "wandb_all_log_all") - valid_modes = [ - "wandb_all_log_all", - "wandb_rank_0_log_all", - "wandb_rank_0_reduce_all", - ] - if self.mode not in valid_modes: - raise ValueError( - f"Invalid WandbBackend mode '{self.mode}'. Must be one of {valid_modes}." - ) + self.reduce_across_ranks = logger_backend_config.get( + "reduce_across_ranks", True + ) + self.share_run_id = logger_backend_config.get("share_run_id", False) async def init( self, role: str, - primary_logger_backend_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: - import wandb - if primary_logger_backend_metadata is None: - primary_logger_backend_metadata = {} + if primary_logger_metadata is None: + primary_logger_metadata = {} + + if role not in ["global", "local"]: + raise ValueError( + f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." + ) + self.name = ( get_actor_name_with_rank() if role == "local" else "global_controller" ) - if self.mode == "wandb_all_log_all" and role == "local": - self.run = wandb.init( - project=self.project, group=self.group, name=self.name - ) - elif self.mode == "wandb_rank_0_log_all": - if role == "global": - # Primary - settings = wandb.Settings( - mode="shared", x_primary=True, x_label="controller_primary" + # Default global mode: only inits on controller + if self.reduce_across_ranks: + if role != "global": + logger.debug( + f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
) - self.run = wandb.init( - project=self.project, group=self.group, settings=settings - ) - # TODO: Make metric definitions automatic or configurable via logger_backend config - self.run.define_metric("global_step") - self.run.define_metric("train/loss", step_metric="global_step") - self.run.define_metric("generate/tokens", step_metric="global_step") - elif role == "local": - # Secondary: Use shared_run_id from primary_logger_backend_metadata - shared_id = primary_logger_backend_metadata.get("shared_run_id") - if shared_id is None: - local_rank = current_rank().rank - raise ValueError( - f"Rank {local_rank}: Shared ID required but not provided" - ) - settings = wandb.Settings( - mode="shared", x_primary=False, x_label=self.name - ) - self.run = wandb.init( - id=shared_id, - project=self.project, - group=self.group, - settings=settings, - ) - elif self.mode == "wandb_rank_0_reduce_all" and role == "global": - self.run = wandb.init(project=self.project, group=self.group) - # self.run.define_metric("global_step") - # self.run.define_metric("train/loss", step_metric="global_step") - # self.run.define_metric("generate/tokens", step_metric="global_step") - else: - logger.debug(f"Skipped init for {self.mode} mode and {role} role") + return + await self._init_global() + + # Per-rank modes based on share_run_id bool + elif role == "global" and self.share_run_id: + await self._init_shared_global() + + elif role == "local": + if self.share_run_id: + await self._init_shared_local(primary_logger_metadata) + else: + await self._init_per_rank() + + async def _init_global(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group) + + async def _init_per_rank(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group, name=self.name) + + async def _init_shared_global(self): + import wandb + + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init(project=self.project, group=self.group, settings=settings) + + async def _init_shared_local(self, primary_metadata: Dict[str, Any]): + import wandb + + shared_id = primary_metadata.get("shared_run_id") + if shared_id is None: + raise ValueError( + f"Shared ID required but not provided for {self.name} backend init" + ) + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) async def log(self, metrics: Dict[str, Any], step: int) -> None: if self.run: @@ -621,12 +610,12 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None: self.run.log(log_data) logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") else: - logger.debug(f"WandbBackend: No run, skipping log for {self.name}") + logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") - def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: - if self.run and self.mode == "wandb_rank_0_log_all": + def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]: + if self.run and not self.reduce_across_ranks and self.share_run_id: return {"shared_run_id": self.run.id} - return None # {} for others + return {} async def finish(self) -> None: if self.run: @@ -634,6 +623,55 @@ async def finish(self) -> None: logger.info(f"WandbBackend {self.name}: Finished run") +class StdAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.sum_sq 
= 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.sum_sq += v * v + self.count += 1 + + def get_value(self) -> float: + if self.count == 0: + return 0.0 + if self.count == 1: + return 0.0 + mean = self.sum / self.count + variance = (self.sum_sq / self.count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "sum_sq": self.sum_sq, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_sum_sq = sum(s["sum_sq"] for s in states) + total_count = sum(s["count"] for s in states) + if total_count == 0: + return 0.0 + if total_count == 1: + return 0.0 + mean = total_sum / total_count + variance = (total_sum_sq / total_count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def reset(self) -> None: + self.sum = 0.0 + self.sum_sq = 0.0 + self.count = 0 + + def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: """Simple mapping between logger_backend type and its class From 185504dd38e520cd3e5b0ef2c8e3a61221a1a969 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 26 Sep 2025 08:53:39 -0700 Subject: [PATCH 18/25] bug fix + register if respawn --- apps/toy_metrics/main.py | 5 +- src/forge/observability/metric_actors.py | 12 +++ src/forge/observability/metrics.py | 118 ++++++++++++----------- 3 files changed, 76 insertions(+), 59 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index d373de565..1d346e752 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -16,7 +16,7 @@ from monarch.actor import current_rank, endpoint, get_or_spawn_controller -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) class TrainActor(ForgeActor): @@ -53,8 +53,7 @@ async def main(): "wandb": { "project": "my_project", "group": group, - "mode": "wandb_rank_0_reduce_all", - "log_per_rank": False, + "reduce_across_ranks": True, }, } diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 9224cd007..febc5d043 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -110,6 +110,7 @@ async def init_backends(self, config: Dict[str, Any]): config (Dict[str, Any]): Config for metric logging where keys are backend names, e.g. {"console": {"log_per_rank": True}, "wandb": {"log_per_rank": False}} """ + self.config = config # Init global logger_backends and states where needed from forge.observability.metrics import get_logger_backend_class @@ -150,6 +151,13 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str): have 4 keys, i.e. 
2 process meshes, each with 2 replicas."""
         self.fetchers[name] = fetcher
 
+        # Self-init for respawned actors
+        if self.config:
+            logger.debug(f"Initializing new LocalFetchActor {name}")
+            await fetcher.init_backends.call(
+                self.metadata_per_primary_backend, self.config
+            )
+
     @endpoint
     async def deregister_fetcher(self, name: str):
         if name not in self.fetchers:
@@ -223,6 +231,10 @@ async def flush(self, step: int):
             ) in self.global_logger_backends.items():
                 await logger_backend.log(reduced_metrics, step)
 
+    @endpoint
+    def get_fetcher_count(self) -> int:
+        return len(self.fetchers)
+
     @endpoint
     async def shutdown(self):
         # Finish per-rank logger_backends via fetchers
diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py
index cfe8f26a3..65f2b01c1 100644
--- a/src/forge/observability/metrics.py
+++ b/src/forge/observability/metrics.py
@@ -295,6 +295,55 @@ def reset(self) -> None:
         self.min_val = float("inf")
 
 
+class StdAccumulator(MetricAccumulator):
+    def __init__(self, reduction: ReductionType):
+        super().__init__(reduction)
+        self.sum = 0.0
+        self.sum_sq = 0.0
+        self.count = 0
+
+    def append(self, value: Any) -> None:
+        v = float(value.item() if hasattr(value, "item") else value)
+        self.sum += v
+        self.sum_sq += v * v
+        self.count += 1
+
+    def get_value(self) -> float:
+        if self.count == 0:
+            return 0.0
+        if self.count == 1:
+            return 0.0
+        mean = self.sum / self.count
+        variance = (self.sum_sq / self.count) - (mean * mean)
+        return max(0.0, variance) ** 0.5
+
+    def get_state(self) -> Dict[str, Any]:
+        return {
+            "reduction_type": self.reduction_type.value,
+            "sum": self.sum,
+            "sum_sq": self.sum_sq,
+            "count": self.count,
+        }
+
+    @classmethod
+    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
+        total_sum = sum(s["sum"] for s in states)
+        total_sum_sq = sum(s["sum_sq"] for s in states)
+        total_count = sum(s["count"] for s in states)
+        if total_count == 0:
+            return 0.0
+        if total_count == 1:
+            return 0.0
+        mean = total_sum / total_count
+        variance = (total_sum_sq / total_count) - (mean * mean)
+        return max(0.0, variance) ** 0.5
+
+    def reset(self) -> None:
+        self.sum = 0.0
+        self.sum_sq = 0.0
+        self.count = 0
+
+
 #############
 # Collector #
 #############
@@ -319,6 +368,7 @@ class MetricCollector:
     def __new__(cls):
         """Singleton per-rank, ensures one instance per process."""
         rank = current_rank().rank
+
         if rank not in cls._instances:
             inst = super().__new__(cls)
             cls._instances[rank] = inst
@@ -332,12 +382,14 @@ def __new__(cls):
         return inst
 
     def __init__(self):
-        if hasattr(self, "_initialized_sync"):
+        if hasattr(self, "_is_initialized"):
             return
+
         self.accumulators: Dict[str, MetricAccumulator] = {}
-        self._initialized_async = False
+        self._is_initialized = False
         self.rank = current_rank().rank
         self.logger_backends: List[LoggerBackend] = []
+        self._is_initialized = True
 
     async def init_backends(
         self,
@@ -353,7 +405,7 @@ async def init_backends(
                 logger backend, e.g., {"wandb": {"run_id": "abc123"}}.
             config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}.
         """
-        if self._initialized_async:
+        if self._is_initialized:
             logger.debug(f"Rank {self.rank}: MetricCollector already initialized")
             return
 
@@ -368,11 +420,14 @@ async def init_backends(
             )
             self.logger_backends.append(logger_backend)
 
-        self._initialized_async = True
+        self._is_initialized = True
 
     def push(
         self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN
     ) -> None:
+        if not self._is_initialized:
+            raise ValueError("Collector not initialized—call init first")
+
         if key not in self.accumulators:
             self.accumulators[key] = reduction.accumulator_class(reduction)
 
@@ -388,10 +443,10 @@ async def flush(
             return_state (bool): Used by GlobalLoggingActor for reduction across all ranks.
                 If False, returns empty dict, else returns the state of all metrics collected.
         Returns:
-            Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
+            Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state},
                 e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
         """
-        if not self._initialized_async:
+        if not self._is_initialized:
             logger.debug(
                 f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first."
             )
@@ -424,7 +479,7 @@ async def flush(
 
     async def shutdown(self):
         """Shutdown logger_backends if initialized."""
-        if not self._initialized_async:
+        if not self._is_initialized:
             logger.debug(
                 f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown"
             )
@@ -623,55 +678,6 @@ async def finish(self) -> None:
         logger.info(f"WandbBackend {self.name}: Finished run")
 
 
-class StdAccumulator(MetricAccumulator):
-    def __init__(self, reduction: ReductionType):
-        super().__init__(reduction)
-        self.sum = 0.0
-        self.sum_sq = 0.0
-        self.count = 0
-
-    def append(self, value: Any) -> None:
-        v = float(value.item() if hasattr(value, "item") else value)
-        self.sum += v
-        self.sum_sq += v * v
-        self.count += 1
-
-    def get_value(self) -> float:
-        if self.count == 0:
-            return 0.0
-        if self.count == 1:
-            return 0.0
-        mean = self.sum / self.count
-        variance = (self.sum_sq / self.count) - (mean * mean)
-        return max(0.0, variance) ** 0.5
-
-    def get_state(self) -> Dict[str, Any]:
-        return {
-            "reduction_type": self.reduction_type.value,
-            "sum": self.sum,
-            "sum_sq": self.sum_sq,
-            "count": self.count,
-        }
-
-    @classmethod
-    def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float:
-        total_sum = sum(s["sum"] for s in states)
-        total_sum_sq = sum(s["sum_sq"] for s in states)
-        total_count = sum(s["count"] for s in states)
-        if total_count == 0:
-            return 0.0
-        if total_count == 1:
-            return 0.0
-        mean = total_sum / total_count
-        variance = (total_sum_sq / total_count) - (mean * mean)
-        return max(0.0, variance) ** 0.5
-
-    def reset(self) -> None:
-        self.sum = 0.0
-        self.sum_sq = 0.0
-        self.count = 0
-
-
 def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]:
     """Simple mapping between logger_backend type and its class
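A note on the StdAccumulator introduced above: each rank keeps only (sum, sum_sq, count), and the cross-rank reduction leans on the identity Var(x) = E[x^2] - E[x]^2, so no raw samples ever leave a rank; the max(0.0, ...) guard absorbs the tiny negative variances this formulation can produce from floating-point cancellation. A standalone sanity check of that algebra (an illustrative sketch, not part of the patch; it only re-implements the same formula):

    import random
    import statistics

    values = [random.gauss(0.0, 2.0) for _ in range(1000)]

    # Shard the stream across 4 "ranks"; each keeps only the partial state,
    # i.e. the same fields StdAccumulator.get_state() serializes.
    states = []
    for rank in range(4):
        shard = values[rank::4]
        states.append(
            {"sum": sum(shard), "sum_sq": sum(v * v for v in shard), "count": len(shard)}
        )

    # Merge exactly the way StdAccumulator.get_reduced_value_from_states does.
    total_sum = sum(s["sum"] for s in states)
    total_sum_sq = sum(s["sum_sq"] for s in states)
    total_count = sum(s["count"] for s in states)
    mean = total_sum / total_count
    merged_std = max(0.0, total_sum_sq / total_count - mean * mean) ** 0.5

    # pstdev is the population std, matching the divide-by-count formula above.
    assert abs(merged_std - statistics.pstdev(values)) < 1e-6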
""" - if self._initialized_async: + if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return @@ -368,11 +420,14 @@ async def init_backends( ) self.logger_backends.append(logger_backend) - self._initialized_async = True + self._is_initialized = True def push( self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN ) -> None: + if not self._is_initialized: + raise ValueError("Collector not initializedโ€”call init first") + if key not in self.accumulators: self.accumulators[key] = reduction.accumulator_class(reduction) @@ -388,10 +443,10 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ - if not self._initialized_async: + if not self._is_initialized: logger.debug( f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." ) @@ -424,7 +479,7 @@ async def flush( async def shutdown(self): """Shutdown logger_backends if initialized.""" - if not self._initialized_async: + if not self._is_initialized: logger.debug( f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" ) @@ -623,55 +678,6 @@ async def finish(self) -> None: logger.info(f"WandbBackend {self.name}: Finished run") -class StdAccumulator(MetricAccumulator): - def __init__(self, reduction: ReductionType): - super().__init__(reduction) - self.sum = 0.0 - self.sum_sq = 0.0 - self.count = 0 - - def append(self, value: Any) -> None: - v = float(value.item() if hasattr(value, "item") else value) - self.sum += v - self.sum_sq += v * v - self.count += 1 - - def get_value(self) -> float: - if self.count == 0: - return 0.0 - if self.count == 1: - return 0.0 - mean = self.sum / self.count - variance = (self.sum_sq / self.count) - (mean * mean) - return max(0.0, variance) ** 0.5 - - def get_state(self) -> Dict[str, Any]: - return { - "reduction_type": self.reduction_type.value, - "sum": self.sum, - "sum_sq": self.sum_sq, - "count": self.count, - } - - @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: - total_sum = sum(s["sum"] for s in states) - total_sum_sq = sum(s["sum_sq"] for s in states) - total_count = sum(s["count"] for s in states) - if total_count == 0: - return 0.0 - if total_count == 1: - return 0.0 - mean = total_sum / total_count - variance = (total_sum_sq / total_count) - (mean * mean) - return max(0.0, variance) ** 0.5 - - def reset(self) -> None: - self.sum = 0.0 - self.sum_sq = 0.0 - self.count = 0 - - def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: """Simple mapping between logger_backend type and its class From 052e937519ce8c1ae391918fad9803e7398f0627 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 26 Sep 2025 09:02:47 -0700 Subject: [PATCH 19/25] it works --- apps/toy_metrics/main.py | 6 ++++-- src/forge/observability/metric_actors.py | 2 +- src/forge/observability/metrics.py | 5 ++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 1d346e752..95cea4656 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -47,13 +47,15 @@ async def main(): group = f"grpo_exp_{int(time.time())}" # Config format: 
From efb639dcca331c3004d2ccc28b17f89a1b0e3796 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Fri, 26 Sep 2025 09:07:29 -0700
Subject: [PATCH 20/25] use procmesh as key

---
 src/forge/controller/provisioner.py       | 6 ++----
 src/forge/observability/metric_actors.py | 1 +
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 8143ec1d1..40ccbfb0b 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -168,8 +168,7 @@ async def _setup_logging(self, procs: ProcMesh) -> None:
         global_logger = await get_or_spawn_controller(
             "global_logger", GlobalLoggingActor
         )
-        process_name = f"proc_mesh_{id(procs)}"
-        await global_logger.register_fetcher.call_one(local_fetcher_actor, process_name)
+        await global_logger.register_fetcher.call_one(local_fetcher_actor, procs)
 
     async def get_proc_mesh(
         self, num_procs: int, with_gpus: bool = False, num_hosts: int | None = None
@@ -259,8 +258,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
                 global_logger = await get_or_spawn_controller(
                     "global_logger", GlobalLoggingActor
                 )
-                process_name = f"proc_mesh_{id(proc_mesh)}"
-                await global_logger.deregister_fetcher.call_one(process_name)
+                await global_logger.deregister_fetcher.call_one(proc_mesh)
 
             if hasattr(proc_mesh, "_gpu_ids"):
                 gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index e7fe246c9..4881d55df 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -163,6 +163,7 @@ async def deregister_fetcher(self, name: str):
         if name not in self.fetchers:
             logger.warning(
                 f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
+                "Available fetchers: {self.fetchers.keys()}"
             )
             return
         del self.fetchers[name]
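Why the switch from f"proc_mesh_{id(procs)}" strings to the ProcMesh itself: id() values can be recycled by the allocator once a mesh is garbage collected, so an id-derived string may silently alias a dead mesh, whereas the live object used as a dict key cannot collide (this assumes ProcMesh hashes by identity, the Python default). A minimal sketch of the pattern, illustrative only:

    class FetcherRegistry:
        """Registry keyed by the mesh object itself rather than a derived string."""

        def __init__(self):
            self.fetchers = {}

        def register(self, mesh, fetcher) -> None:
            self.fetchers[mesh] = fetcher  # identity-based key, no id() reuse hazard

        def deregister(self, mesh) -> None:
            self.fetchers.pop(mesh, None)

The trade-off is that the registry now holds a strong reference to every registered mesh, which is acceptable here because deregister_fetcher is invoked from stop_proc_mesh.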
From 781906d49151799236989c93a08d313391c1a968 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Fri, 26 Sep 2025 09:08:41 -0700
Subject: [PATCH 21/25] docstring

---
 src/forge/controller/provisioner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 40ccbfb0b..e1d3fa613 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -151,7 +151,8 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
 
     async def _setup_logging(self, procs: ProcMesh) -> None:
         """Spawn and register local fetcher for metric logging on each process.
-        When a service is spawned, we create for each rank a LocalFetcherActor and
+
+        When any process is created, we create for each rank a LocalFetcherActor and
         store it at GlobalLoggingActor. Backends (e.g. wandb) should be eagerly instantiated
         later in main by calling `global_logger.initialize_backends.call_one(logging_config)`
         """

From f2a9e09e259464e5b1e6d42b85003008eed05dad Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Fri, 26 Sep 2025 09:13:07 -0700
Subject: [PATCH 22/25] remove protected imports

---
 src/forge/controller/provisioner.py       |  8 ++------
 src/forge/observability/metric_actors.py |  13 +++++--------
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index e1d3fa613..eb7458486 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -28,6 +28,8 @@
 from monarch.tools import commands
 from monarch.tools.components import hyperactor
 from monarch.tools.config import Config
 
+from forge.observability.metric_actors import GlobalLoggingActor, LocalFetcherActor
+
 from forge.types import ProcessConfig
 
 logger = logging.getLogger(__name__)
@@ -156,10 +158,6 @@ async def _setup_logging(self, procs: ProcMesh) -> None:
         store it at GlobalLoggingActor. Backends (e.g. wandb) should be eagerly instantiated
         later in main by calling `global_logger.initialize_backends.call_one(logging_config)`
         """
-        from forge.observability.metric_actors import (
-            GlobalLoggingActor,
-            LocalFetcherActor,
-        )
 
         local_fetcher_actor = await procs.spawn(
             "local_fetcher_actor", LocalFetcherActor
         )
@@ -254,8 +252,6 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
         async with self._lock:
             # Deregister local logger from global logger
             if hasattr(proc_mesh, "_local_fetcher"):
-                from forge.observability.metric_actors import GlobalLoggingActor
-
                 global_logger = await get_or_spawn_controller(
                     "global_logger", GlobalLoggingActor
                 )
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 4881d55df..f9ddd182a 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -10,6 +10,11 @@
 
 from monarch.actor import Actor, endpoint
 
+from forge.observability.metrics import (
+    get_logger_backend_class,
+    MetricCollector,
+    reduce_metrics_states,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -39,7 +44,6 @@ async def flush(
             Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
                 e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
         """
-        from forge.observability.metrics import MetricCollector
 
         collector = MetricCollector()
         result = await collector.flush(step, return_state=return_state)
         return result
@@ -52,14 +56,12 @@ async def init_backends(
         config: Dict[str, Any],
     ):
         """Init local (per-rank) logger backends and MetricCollector."""
-        from forge.observability.metrics import MetricCollector
 
         collector = MetricCollector()
         await collector.init_backends(metadata_per_primary_backend, config)
 
     @endpoint
     async def shutdown(self):
-        from forge.observability.metrics import MetricCollector
 
         collector = MetricCollector()
         await collector.shutdown()
@@ -112,9 +114,6 @@ async def init_backends(self, config: Dict[str, Any]):
         """
         self.config = config
 
-        # Init global logger_backends and states where needed
-        from forge.observability.metrics import get_logger_backend_class
-
         for backend_name, backend_config in config.items():
             backend = get_logger_backend_class(backend_name)(backend_config)
             await backend.init(role="global")
@@ -221,8 +220,6 @@ async def flush(self, step: int):
             return
 
         # Reduce
-        from forge.observability.metrics import reduce_metrics_states
-
         reduced_metrics = reduce_metrics_states(all_local_states)
 
         # Log to each global logger_backend
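The deferred imports deleted above were a workaround for an import cycle between provisioner.py and metric_actors.py; a function-local import delays module resolution until call time, at the cost of hiding the dependency and paying a sys.modules lookup on every call. Roughly, the before/after shape (simplified module stubs, not the real files):

    # Before: cycle dodged by importing at call time.
    def flush_step(step):
        from forge.observability.metrics import MetricCollector  # resolved on each call
        return MetricCollector().flush(step)

    # After: the cycle is gone, so the dependency is stated once at module top.
    from forge.observability.metrics import MetricCollector

    def flush_step(step):
        return MetricCollector().flush(step)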
""" - from forge.observability.metrics import MetricCollector collector = MetricCollector() result = await collector.flush(step, return_state=return_state) @@ -52,14 +56,12 @@ async def init_backends( config: Dict[str, Any], ): """Init local (per-rank) logger backends and MetricCollector.""" - from forge.observability.metrics import MetricCollector collector = MetricCollector() await collector.init_backends(metadata_per_primary_backend, config) @endpoint async def shutdown(self): - from forge.observability.metrics import MetricCollector collector = MetricCollector() await collector.shutdown() @@ -112,9 +114,6 @@ async def init_backends(self, config: Dict[str, Any]): """ self.config = config - # Init global logger_backends and states where needed - from forge.observability.metrics import get_logger_backend_class - for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) await backend.init(role="global") @@ -221,8 +220,6 @@ async def flush(self, step: int): return # Reduce - from forge.observability.metrics import reduce_metrics_states - reduced_metrics = reduce_metrics_states(all_local_states) # Log to each global logger_backend From 8e157bd2295775c25bce449ced621c8556688e8e Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Fri, 26 Sep 2025 09:44:24 -0700 Subject: [PATCH 23/25] create get_metric_logger --- apps/toy_metrics/main.py | 12 +++--- src/forge/observability/__init__.py | 49 ++++++++++++++++++++++++ src/forge/observability/metric_actors.py | 21 +++++++++- 3 files changed, 74 insertions(+), 8 deletions(-) diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py index 95cea4656..986cad3ca 100644 --- a/apps/toy_metrics/main.py +++ b/apps/toy_metrics/main.py @@ -11,10 +11,10 @@ from forge.controller.actor import ForgeActor from forge.controller.provisioner import shutdown -from forge.observability.metric_actors import GlobalLoggingActor +from forge.observability.metric_actors import get_metric_logger from forge.observability.metrics import record_metric, ReductionType -from monarch.actor import current_rank, endpoint, get_or_spawn_controller +from monarch.actor import current_rank, endpoint logging.basicConfig(level=logging.DEBUG) @@ -66,18 +66,18 @@ async def main(): generator = await GeneratorActor.options(**service_config).as_service() # Now init config on global (inits backends eagerly across fetchers) - global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor) - await global_logger.init_backends.call_one(config) + mlogger = await get_metric_logger() + await mlogger.init_backends.call_one(config) for i in range(3): print(f"\n=== Global Step {i} ===") await trainer.train_step.call(i) for sub in range(3): await generator.generate_step.call(i, sub) - await global_logger.flush.call_one(i) + await mlogger.flush.call_one(i) # shutdown - await asyncio.gather(global_logger.shutdown.call_one()) + await mlogger.shutdown.call_one() await asyncio.gather( trainer.shutdown(), diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 2e41cd717..10787a1d0 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -3,3 +3,52 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+
+from .metric_actors import get_metric_logger, GlobalLoggingActor, LocalFetcherActor
+from .metrics import (
+    ConsoleBackend,
+    # Utility functions
+    get_actor_name_with_rank,
+    get_logger_backend_class,
+    # Backend classes
+    LoggerBackend,
+    MaxAccumulator,
+    MeanAccumulator,
+    # Accumulator classes
+    MetricAccumulator,
+    MetricCollector,
+    MinAccumulator,
+    record_metric,
+    reduce_metrics_states,
+    ReductionType,
+    StdAccumulator,
+    SumAccumulator,
+    WandbBackend,
+)
+
+__all__ = [
+    # Main API functions
+    "record_metric",
+    "reduce_metrics_states",
+    "get_actor_name_with_rank",
+    "get_logger_backend_class",
+    "get_metric_logger",
+    # Enums
+    "ReductionType",
+    # Actor classes
+    "GlobalLoggingActor",
+    "LocalFetcherActor",
+    # Collector
+    "MetricCollector",
+    # Backend classes
+    "LoggerBackend",
+    "ConsoleBackend",
+    "WandbBackend",
+    # Accumulator classes
+    "MetricAccumulator",
+    "MeanAccumulator",
+    "SumAccumulator",
+    "MaxAccumulator",
+    "MinAccumulator",
+    "StdAccumulator",
+]
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index f9ddd182a..79b1f954b 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -8,7 +8,7 @@
 import logging
 from typing import Any, Dict, Optional
 
-from monarch.actor import Actor, endpoint
+from monarch.actor import Actor, endpoint, get_or_spawn_controller
 
 from forge.observability.metrics import (
     get_logger_backend_class,
@@ -18,6 +18,23 @@
 
 logger = logging.getLogger(__name__)
 
+_global_logger = None
+
+
+async def get_metric_logger():
+    """Get or spawn the global logging actor.
+
+    Returns:
+        GlobalLoggingActor: The global logging actor instance.
+
+    Example:
+        mlogger = await get_metric_logger()
+        # spawn your processes ...
+        await mlogger.init_backends(config)
+    """
+    _global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
+    return _global_logger
+
 
 class LocalFetcherActor(Actor):
     """Thin per-process actor used to trigger MetricCollector singleton
@@ -152,7 +169,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str):
 
         # Self-init for respawned actors
         if self.config:
-            logger.debug(f"Initializing new LocalFetchActor {name}")
+            logger.debug(f"Initializing new LocalFetcherActor {name}")
             await fetcher.init_backends.call(
                 self.metadata_per_primary_backend, self.config
             )

From 5736c795f33acfda34bad9e878e2a11d6b082d8d Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Fri, 26 Sep 2025 09:48:29 -0700
Subject: [PATCH 24/25] call became fanout

---
 apps/toy_metrics/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py
index 986cad3ca..9ce4337e9 100644
--- a/apps/toy_metrics/main.py
+++ b/apps/toy_metrics/main.py
@@ -71,9 +71,9 @@ async def main():
 
     for i in range(3):
         print(f"\n=== Global Step {i} ===")
-        await trainer.train_step.call(i)
+        await trainer.train_step.fanout(i)
         for sub in range(3):
-            await generator.generate_step.call(i, sub)
+            await generator.generate_step.fanout(i, sub)
         await mlogger.flush.call_one(i)
 
     # shutdown
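On the .call() to .fanout() rename above: the change suggests the service API distinguishes a broadcast to every replica (fanout) from addressing a single actor (call_one, which GlobalLoggingActor keeps, since it is a singleton controller). Assuming those semantics, the toy loop reads:

    await trainer.train_step.fanout(i)            # every TrainActor replica runs step i
    await generator.generate_step.fanout(i, sub)  # every GeneratorActor replica records metrics
    await mlogger.flush.call_one(i)               # one controller reduces and logs step i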
From a426cd5381e22bb3474b9097793606f600fa9e21 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Sun, 28 Sep 2025 22:31:12 -0700
Subject: [PATCH 25/25] upstream changes

---
 apps/toy_metrics/main.py                  |   4 +-
 src/forge/controller/provisioner.py       |  35 +-------
 src/forge/observability/__init__.py       |   4 +-
 src/forge/observability/metric_actors.py | 102 +++++++++++++++++++----
 4 files changed, 96 insertions(+), 49 deletions(-)

diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py
index 9ce4337e9..cd542df44 100644
--- a/apps/toy_metrics/main.py
+++ b/apps/toy_metrics/main.py
@@ -11,7 +11,7 @@
 
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import shutdown
-from forge.observability.metric_actors import get_metric_logger
+from forge.observability.metric_actors import setup_metric_logger
 from forge.observability.metrics import record_metric, ReductionType
 
 from monarch.actor import current_rank, endpoint
@@ -60,13 +60,13 @@ async def main():
     }
 
     service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
+    mlogger = await setup_metric_logger()
 
     # Spawn services first (triggers registrations via provisioner hook)
     trainer = await TrainActor.options(**service_config).as_service()
     generator = await GeneratorActor.options(**service_config).as_service()
 
     # Now init config on global (inits backends eagerly across fetchers)
-    mlogger = await get_metric_logger()
     await mlogger.init_backends.call_one(config)
 
     for i in range(3):
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index eb7458486..c0670db1f 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -16,19 +16,12 @@
 import monarch
 from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
 from monarch._src.actor.shape import NDSlice, Shape
-from monarch.actor import (
-    Actor,
-    endpoint,
-    get_or_spawn_controller,
-    HostMesh,
-    ProcMesh,
-    this_host,
-)
+from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
 from monarch.tools import commands
 from monarch.tools.components import hyperactor
 from monarch.tools.config import Config
 
-from forge.observability.metric_actors import GlobalLoggingActor, LocalFetcherActor
+from forge.observability.metric_actors import setup_metric_logger
 
 from forge.types import ProcessConfig
 
@@ -151,24 +144,6 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
             server_name,
         )
 
-    async def _setup_logging(self, procs: ProcMesh) -> None:
-        """Spawn and register local fetcher for metric logging on each process.
-
-        When any process is created, we create for each rank a LocalFetcherActor and
-        store it at GlobalLoggingActor. Backends (e.g. wandb) should be eagerly instantiated
-        later in main by calling `global_logger.initialize_backends.call_one(logging_config)`
-        """
-
-        local_fetcher_actor = await procs.spawn(
-            "local_fetcher_actor", LocalFetcherActor
-        )
-        procs._local_fetcher = local_fetcher_actor
-
-        global_logger = await get_or_spawn_controller(
-            "global_logger", GlobalLoggingActor
-        )
-        await global_logger.register_fetcher.call_one(local_fetcher_actor, procs)
-
     async def get_proc_mesh(
         self, num_procs: int, with_gpus: bool = False, num_hosts: int | None = None
     ):
@@ -243,7 +218,7 @@ def bootstrap(gpu_ids: list[str]):
             self._proc_server_map[procs] = server_name
 
             # Spawn local logging actor on each process and register with global logger
-            await self._setup_logging(procs)
+            _ = await setup_metric_logger(procs)
 
             return procs
 
@@ -252,9 +227,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
         async with self._lock:
             # Deregister local logger from global logger
             if hasattr(proc_mesh, "_local_fetcher"):
-                global_logger = await get_or_spawn_controller(
-                    "global_logger", GlobalLoggingActor
-                )
+                global_logger = await setup_metric_logger(proc_mesh)
                 await global_logger.deregister_fetcher.call_one(proc_mesh)
 
             if hasattr(proc_mesh, "_gpu_ids"):
diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index 10787a1d0..4f630b8af 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .metric_actors import get_metric_logger, GlobalLoggingActor, LocalFetcherActor
+from .metric_actors import GlobalLoggingActor, LocalFetcherActor, setup_metric_logger
 from .metrics import (
     ConsoleBackend,
     # Utility functions
@@ -32,7 +32,7 @@
     "reduce_metrics_states",
     "get_actor_name_with_rank",
     "get_logger_backend_class",
-    "get_metric_logger",
+    "setup_metric_logger",
     # Enums
     "ReductionType",
     # Actor classes
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index 79b1f954b..53cca81a3 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -8,7 +8,7 @@
 import logging
 from typing import Any, Dict, Optional
 
-from monarch.actor import Actor, endpoint, get_or_spawn_controller
+from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
 
 from forge.observability.metrics import (
     get_logger_backend_class,
@@ -21,19 +21,89 @@
 
 logger = logging.getLogger(__name__)
 
 _global_logger = None
 
 
-async def get_metric_logger():
-    """Get or spawn the global logging actor.
+async def setup_metric_logger(
+    proc_mesh: ProcMesh | None = None,
+) -> "GlobalLoggingActor":
+    """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
+    if not already initialized, registers it with the GlobalLoggingActor and returns the
+    GlobalLoggingActor instance.
+
+    There are primarily two ways to use this function:
+    1. In the main process, call `setup_metric_logger()` to get the global logger.
+    2. In service processes, call `setup_metric_logger(proc_mesh)` to register the
+       local fetcher with the global logger.
+
+    Args:
+        proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
+            uses `monarch.actor.this_proc()`.
 
     Returns:
-        GlobalLoggingActor: The global logging actor instance.
+        GlobalLoggingActor: The global logging controller.
+
+    Raises:
+        ValueError: If the logging state is inconsistent, i.e. the fetcher is already
+            registered, but only in the process or the global logger.
 
     Example:
-        mlogger = await get_metric_logger()
-        # spawn your processes ...
-        await mlogger.init_backends(config)
+        from forge.observability.metric_actors import setup_metric_logger
+        from forge.observability.metrics import record_metric
+
+        # Main process setup
+        mlogger = await setup_metric_logger()
+
+        # Initialize services...
+        policy = await Policy.as_service(...)
+
+        # Initialize logging backends after all local fetchers are registered
+        # so each rank can have its own.
+        await mlogger.init_backends({
+            "console": {"reduce_across_ranks": True},
+            "wandb": {"project": "my_project", "reduce_across_ranks": False}
+        })
+
+        # Training loop
+        for step in range(max_steps):
+            record_metric("loss", 1.2, step, reduction_type=ReductionType.MEAN)
+            # ... training code with record_metric() calls ...
+            await mlogger.flush(step)  # Log metrics for this step
+
+        # Shutdown
+        await mlogger.shutdown()
     """
-    _global_logger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
-    return _global_logger
+    # Get or create the singleton global logger
+    global _global_logger
+    if _global_logger is None:
+        _global_logger = await get_or_spawn_controller(
+            "global_logger", GlobalLoggingActor
+        )
+    global_logger = _global_logger
+
+    # Determine process context
+    proc = proc_mesh if proc_mesh is not None else this_proc()
+
+    # Check current state for consistency
+    proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
+    global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
+
+    # Consistency check: both should be in sync
+    if proc_has_local_fetcher != global_logger_has_local_fetcher:
+        raise ValueError(
+            f"Inconsistent logging state for proc {proc}: "
+            f"proc has _local_fetcher={proc_has_local_fetcher}, "
+            f"but global_logger has registration={global_logger_has_local_fetcher}. "
+            f"This indicates a bug in logging setup/teardown. "
+            f"Both should be True (already setup) or both False (needs setup)."
+        )
+
+    # Setup local_fetcher_actor if needed
+    if not proc_has_local_fetcher:
+        local_fetcher_actor = await proc.spawn(
+            "local_fetcher_actor", LocalFetcherActor, global_logger
+        )
+        await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
+        proc._local_fetcher = local_fetcher_actor
+
+    return global_logger
 
 
 class LocalFetcherActor(Actor):
@@ -46,6 +116,7 @@ class LocalFetcherActor(Actor):
 
     def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None:
         self.global_logger = global_logger
+        _is_initialized = False
 
     @endpoint
     async def flush(
@@ -61,7 +132,6 @@ async def flush(
             Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state},
                 e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
         """
-
         collector = MetricCollector()
         result = await collector.flush(step, return_state=return_state)
         return result
@@ -73,7 +143,6 @@ async def init_backends(
         config: Dict[str, Any],
     ):
         """Init local (per-rank) logger backends and MetricCollector."""
-
         collector = MetricCollector()
         await collector.init_backends(metadata_per_primary_backend, config)
 
@@ -161,8 +230,8 @@ async def init_backends(self, config: Dict[str, Any]):
         await asyncio.gather(*tasks, return_exceptions=True)
 
     @endpoint
-    async def register_fetcher(self, fetcher: LocalFetcherActor, name: str):
+    async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh):
         """Registers a fetcher with the global actor. Each key represents a process mesh.
         If there are 2 processes, each with 2 replicas with N gpus, we would
         have 4 keys, i.e. 2 process meshes, each with 2 replicas."""
         self.fetchers[name] = fetcher
@@ -175,11 +244,11 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str):
             )
 
     @endpoint
-    async def deregister_fetcher(self, name: str):
+    async def deregister_fetcher(self, name: str | ProcMesh):
         if name not in self.fetchers:
             logger.warning(
                 f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
-                "Available fetchers: {self.fetchers.keys()}"
+                f"Available fetchers: {self.fetchers.keys()}"
             )
             return
         del self.fetchers[name]
@@ -246,6 +315,11 @@ async def flush(self, step: int):
             ) in self.global_logger_backends.items():
                 await logger_backend.log(reduced_metrics, step)
 
+    @endpoint
+    def has_fetcher(self, name: str | ProcMesh) -> bool:
+        """Check if a fetcher is registered with the given name."""
+        return name in self.fetchers
+
     @endpoint
     def get_fetcher_count(self) -> int:
         return len(self.fetchers)