diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f2b7f6e7d..f256c5687 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -8,6 +8,7 @@ import asyncio import pprint +import time import uuid from dataclasses import dataclass from typing import Any, Callable @@ -16,6 +17,10 @@ import torch.nn.functional as F import torchstore as ts from datasets import load_dataset +from forge.actors._torchstore_utils import ( + get_dcp_whole_state_dict_key, + get_param_prefix, +) from forge.actors.policy import Policy from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer @@ -239,6 +244,23 @@ async def pad_token(self): return self._tokenizer.pad_token_id +async def drop_weights(version: int): + print(f"Dropping weights @ version {version}") + start_time = time.perf_counter() + prefix = get_param_prefix(version) + matching_keys = await ts.keys(prefix) + # TODO: once we have something like `get_meta()` in torchstore, we can just + # query the type of the object instead of relying on keys. + dcp_key = get_dcp_whole_state_dict_key(version) + if dcp_key in matching_keys: + dcp_handle = await ts.get(dcp_key) + dcp_handle.drop() + for key in matching_keys: + await ts.delete(key) + elapsed = time.perf_counter() - start_time + print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds") + + async def main(cfg: DictConfig): """Main GRPO training loop with rollout and training processes.""" group_size = cfg.group_size @@ -362,6 +384,8 @@ async def continuous_training(): mlogger.log("loss/training_step", loss, training_step) await trainer.push_weights.fanout(training_step) await policy.update_weights.fanout(training_step) + if training_step >= 2: + await drop_weights(training_step - 1) print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 04ea8efe9..7f1a65e1a 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -1,5 +1,5 @@ # Grouped Relative Policy Optimization (GRPO) -# >>> python -m apps.grpo.qwen3_1_7b --config apps/grpo/qwen3_1_7b.yaml +# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml # Global configuration group_size: 8 diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 51ca387e5..916990c62 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -19,7 +19,7 @@ dataset: # Policy configuration policy: - use_vllm_builtin_load: false + use_vllm_builtin_load: true engine_config: model: ${model} tensor_parallel_size: 2 @@ -33,8 +33,8 @@ policy: # Trainer configuration trainer: - vllm_tp_DEPRECATED: ${policy.engine_config.tensor_parallel_size} - use_vllm_builtin_load: false + use_dcp: true + use_vllm_builtin_load: true model: name: qwen3 flavor: 8B diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py new file mode 100644 index 000000000..cd542df44 --- /dev/null +++ b/apps/toy_metrics/main.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import asyncio + +import logging +import time + +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import shutdown +from forge.observability.metric_actors import setup_metric_logger +from forge.observability.metrics import record_metric, ReductionType + +from monarch.actor import current_rank, endpoint + +logging.basicConfig(level=logging.DEBUG) + + +class TrainActor(ForgeActor): + """Example training actor that records loss metrics.""" + + @endpoint + async def train_step(self, step: int): + rank = current_rank().rank + value = rank * 1000 + 100 * step + print(f"[TRAIN] Rank {rank}: Step {step}, loss={value}") + record_metric("train/loss", value) + + +class GeneratorActor(ForgeActor): + """Example generation actor that records token count metrics.""" + + @endpoint + async def generate_step(self, step: int, substep: int): + rank = current_rank().rank + value = rank * 1000 + step * 100 + substep * 10 + print(f"[GEN] Rank {rank}: Step {step}.{substep}, tokens={value}") + record_metric("generate/tokens", value, ReductionType.SUM) + + +# Main +async def main(): + """Example demonstrating distributed metric logging with different backends.""" + group = f"grpo_exp_{int(time.time())}" + + # Config format: {backend_name: backend_config_dict} + # Each backend can specify reduce_across_ranks to control distributed logging behavior + config = { + "console": {"reduce_across_ranks": True}, + "wandb": { + "project": "my_project", + "group": group, + "reduce_across_ranks": True, + # Only useful if NOT reduce_across_ranks. + "share_run_id": False, # Share run ID across ranks -- Not recommended. + }, + } + + service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} + mlogger = await setup_metric_logger() + + # Spawn services first (triggers registrations via provisioner hook) + trainer = await TrainActor.options(**service_config).as_service() + generator = await GeneratorActor.options(**service_config).as_service() + + # Now init config on global (inits backends eagerly across fetchers) + await mlogger.init_backends.call_one(config) + + for i in range(3): + print(f"\n=== Global Step {i} ===") + await trainer.train_step.fanout(i) + for sub in range(3): + await generator.generate_step.fanout(i, sub) + await mlogger.flush.call_one(i) + + # shutdown + await mlogger.shutdown.call_one() + + await asyncio.gather( + trainer.shutdown(), + generator.shutdown(), + ) + + await shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/actors/_torchstore_utils.py index 558667fcb..bc0d55c3b 100644 --- a/src/forge/actors/_torchstore_utils.py +++ b/src/forge/actors/_torchstore_utils.py @@ -3,19 +3,49 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import shutil from dataclasses import dataclass import torch import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.metadata import Metadata as DcpMeta +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + KEY_DELIM = "." 
+DCP_WHOLE_STATE_TAG = "dcp_whole_state_dict" @dataclass class DcpHandle: - checkpoint_id: str = "" + checkpoint_id: str | None = None metadata: DcpMeta | None = None + param_names: list[str] | None = None + + def drop(self) -> None: + if self.checkpoint_id is None: + raise ValueError("Dropping a null DcpHandle") + if self.checkpoint_id.startswith("manifold://"): + # Probably don't need to delete the checkpoint if it's on manifold + logger.warning( + f"Skipping deletion of {self.checkpoint_id} since it's on manifold" + ) + self.checkpoint_id = None + self.metadata = None + self.param_names = None + return + + try: + shutil.rmtree(self.checkpoint_id, ignore_errors=False) + logger.debug(f"Removed old weights at {self.checkpoint_id}") + except OSError as e: + logger.error(f"Error deleting {self.checkpoint_id}: {e}") + finally: + self.checkpoint_id = None + self.metadata = None + self.param_names = None def load_tensor_from_dcp(handle: DcpHandle, param_name) -> torch.Tensor: @@ -35,3 +65,7 @@ def get_param_key(policy_version: int, name: str) -> str: def extract_param_name(key: str) -> str: return KEY_DELIM.join(key.split(KEY_DELIM)[1:]) + + +def get_dcp_whole_state_dict_key(policy_version: int) -> str: + return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}" diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index e69290346..8963d178b 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -43,8 +43,8 @@ from vllm.worker.worker_base import WorkerWrapperBase from forge.actors._torchstore_utils import ( - DcpHandle, extract_param_name, + get_dcp_whole_state_dict_key, get_param_key, get_param_prefix, load_tensor_from_dcp, @@ -481,8 +481,6 @@ class PolicyWorker(ForgeActor): # TODO: remove this later since no plumbing exists to change this value. # Also, whether to use dcp or not can be inferred from torchstore get() call. use_dcp: bool = True - # Cache hf param names on first update call. - hf_param_names = [] # used for tesing purposes only _test_prev_params = {} @@ -560,28 +558,31 @@ async def update(self, version: int): logger.debug(f"{prefix=}") matching_keys = await ts.keys(prefix) logger.debug(f"{matching_keys=}") - if not self.hf_param_names: - self.hf_param_names = [extract_param_name(key) for key in matching_keys] + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) loaded_weights = set() - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. 
start = time.perf_counter() - for name in self.hf_param_names: - param_key = get_param_key(version, name) - tensor_or_handle = await ts.get(param_key) - if isinstance(tensor_or_handle, torch.Tensor): - param = tensor_or_handle - elif isinstance(tensor_or_handle, DcpHandle): - logger.info(f"Loading {name} from DCP with handle {tensor_or_handle}") - param = load_tensor_from_dcp(tensor_or_handle, name) - logger.info(f"Loaded {name} from DCP with handle {tensor_or_handle}") - else: - raise RuntimeError( - f"Unexpected type for {param_key}: {type(tensor_or_handle)}" - ) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) + # Entire state dict is stored in a single DCP handle + if dcp_whole_state_dict_key in matching_keys: + logger.info( + f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}" + ) + dcp_handle = await ts.get(dcp_whole_state_dict_key) + hf_param_names = dcp_handle.param_names + for name in hf_param_names: + param = load_tensor_from_dcp(dcp_handle, name) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + else: # Load each parameter from torchstore directly without DCP + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) logger.info( f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds" ) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index de6660b05..3fd0162a0 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -36,7 +36,11 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import DcpHandle, get_param_key +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) from forge.controller import ForgeActor from forge.data.utils import batch_to_device @@ -328,11 +332,12 @@ async def _push_weights_DEPRECATED( # noqa: N802 @endpoint async def push_weights(self, policy_version: int) -> None: """Push weights to torchstore in HF format.""" + logger.info(f"Pushing weights for policy version {policy_version}") if not self.use_vllm_builtin_load: return await self._push_weights_DEPRECATED( policy_version, self.vllm_tp_DEPRECATED ) - + start_time = time.perf_counter() if "model" not in self.engine.checkpointer.states: raise RuntimeError("Model state not found in checkpointer state") @@ -344,21 +349,24 @@ async def push_weights(self, policy_version: int) -> None: ) hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) if self.use_dcp: - # we could use dcp.save() to save the whole state dict, - # but I don't want too much deviation between the two code paths - for name, param in hf_state_dict.items(): - key = get_param_key(policy_version, name) - dcp_id = f"{self.dcp_path}/{key}" - metadata = dcp.save( - checkpoint_id=dcp_id, - state_dict={name: param}, - ) - dcp_handle = DcpHandle(checkpoint_id=dcp_id, metadata=metadata) - await ts.put(key, dcp_handle) + key = get_dcp_whole_state_dict_key(policy_version) + dcp_id = f"{self.dcp_path}/{key}" + 
storage_writer = torch.distributed.checkpoint.FileSystemWriter( + dcp_id, single_file_per_rank=False, thread_count=8 + ) + metadata = dcp.save(storage_writer=storage_writer, state_dict=hf_state_dict) + dcp_handle = DcpHandle( + checkpoint_id=dcp_id, + metadata=metadata, + param_names=hf_state_dict.keys(), + ) + await ts.put(key, dcp_handle) else: for name, param in hf_state_dict.items(): key = get_param_key(policy_version, name) await ts.put(key, param) + end_time = time.perf_counter() + logger.info("Completed weights push in %.2f seconds", end_time - start_time) @endpoint async def cleanup(self) -> None: diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index 71d35c433..8f7c2f420 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from .actor import ForgeActor from .proc_mesh import get_proc_mesh, stop_proc_mesh @@ -24,9 +23,4 @@ async def spawn_actors( return actors -__all__ = [ - "spawn_actors", - "stop_proc_mesh", - "get_proc_mesh", - "ForgeActor", -] +__all__ = ["spawn_actors", "stop_proc_mesh", "get_proc_mesh", "ForgeActor"] diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 26d51ea5c..c0670db1f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,6 +21,8 @@ from monarch.tools.components import hyperactor from monarch.tools.config import Config +from forge.observability.metric_actors import setup_metric_logger + from forge.types import ProcessConfig logger = logging.getLogger(__name__) @@ -215,11 +217,19 @@ def bootstrap(gpu_ids: list[str]): self._server_names.append(server_name) self._proc_server_map[procs] = server_name + # Spawn local logging actor on each process and register with global logger + _ = await setup_metric_logger(procs) + return procs async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" async with self._lock: + # Deregister local logger from global logger + if hasattr(proc_mesh, "_local_fetcher"): + global_logger = await setup_metric_logger(proc_mesh) + await global_logger.deregister_fetcher.call_one(proc_mesh) + if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] gpu_manager.release_gpus(proc_mesh._gpu_ids) diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py new file mode 100644 index 000000000..4f630b8af --- /dev/null +++ b/src/forge/observability/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .metric_actors import GlobalLoggingActor, LocalFetcherActor, setup_metric_logger +from .metrics import ( + ConsoleBackend, + # Utility functions + get_actor_name_with_rank, + get_logger_backend_class, + # Backend classes + LoggerBackend, + MaxAccumulator, + MeanAccumulator, + # Accumulator classes + MetricAccumulator, + MetricCollector, + MinAccumulator, + record_metric, + reduce_metrics_states, + ReductionType, + StdAccumulator, + SumAccumulator, + WandbBackend, +) + +__all__ = [ + # Main API functions + "record_metric", + "reduce_metrics_states", + "get_actor_name_with_rank", + "get_logger_backend_class", + "setup_metric_logger", + # Enums + "ReductionType", + # Actor classes + "GlobalLoggingActor", + "LocalFetcherActor", + # Collector + "MetricCollector", + # Backend classes + "LoggerBackend", + "ConsoleBackend", + "WandbBackend", + # Accumulator classes + "MetricAccumulator", + "MeanAccumulator", + "SumAccumulator", + "MaxAccumulator", + "MinAccumulator", + "StdAccumulator", +] diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py new file mode 100644 index 000000000..53cca81a3 --- /dev/null +++ b/src/forge/observability/metric_actors.py @@ -0,0 +1,335 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +from typing import Any, Dict, Optional + +from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc + +from forge.observability.metrics import ( + get_logger_backend_class, + MetricCollector, + reduce_metrics_states, +) + +logger = logging.getLogger(__name__) + +_global_logger = None + + +async def setup_metric_logger( + proc_mesh: ProcMesh | None = None, +) -> "GlobalLoggingActor": + """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), + if not already initialized, registers it with the GlobalLoggingActor and returns the + GlobalLoggingActor instance. + + There are primarily two ways to use this function: + 1. In the main process, call `setup_metric_logger()` to get the global logger. + 2. In service processes, call `setup_metric_logger(proc_mesh)` to register the + local fetcher with the global logger. + + Args: + proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, + uses `monarch.actor.this_proc()`. + + Returns: + GlobalLoggingActor: The global logging controller. + + Raises: + ValueError: If the logging state is inconsistent, i.e. the fetcher is already + registered, but only in the process or the global logger. + + Example: + from forge.observability.metric_actors import setup_metric_logger + from forge.observability.metrics import record_metric + + # Main process setup + mlogger = await setup_metric_logger() + + # Initialize services... + policy = await Policy.as_service(...) + + # Initialize logging backends after all local fetchers are registered + # so each rank can have its own. + await mlogger.init_backends({ + "console": {"reduce_across_ranks": True}, + "wandb": {"project": "my_project", "reduce_across_ranks": False} + }) + + # Training loop + for step in range(max_steps): + record_metric("loss", 1.2, step, reduction_type=ReductionType.MEAN) + # ... training code with record_metric() calls ... 
+ await mlogger.flush(step) # Log metrics for this step + + # Shutdown + await mlogger.shutdown() + """ + # Get or create the singleton global logger + global _global_logger + if _global_logger is None: + _global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + global_logger = _global_logger + + # Determine process context + proc = proc_mesh if proc_mesh is not None else this_proc() + + # Check current state for consistency + proc_has_local_fetcher = hasattr(proc, "_local_fetcher") + global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) + + # Consistency check: both should be in sync + if proc_has_local_fetcher != global_logger_has_local_fetcher: + raise ValueError( + f"Inconsistent logging state for proc {proc}: " + f"proc has _local_fetcher={proc_has_local_fetcher}, " + f"but global_logger has registration={global_logger_has_local_fetcher}. " + f"This indicates a bug in logging setup/teardown. " + f"Both should be True (already setup) or both False (needs setup)." + ) + + # Setup local_fetcher_actor if needed + if not proc_has_local_fetcher: + local_fetcher_actor = await proc.spawn( + "local_fetcher_actor", LocalFetcherActor, global_logger + ) + await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) + proc._local_fetcher = local_fetcher_actor + + return global_logger + + +class LocalFetcherActor(Actor): + """Thin per-process actor used to trigger MetricCollector singleton + operations without direct access. It is what GlobalLoggingActor + uses to broadcast inits/flushes across ranks. + + GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector + """ + + def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + self.global_logger = global_logger + _is_initialized = False + + @endpoint + async def flush( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): train step used by backends to align all metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. + """ + collector = MetricCollector() + result = await collector.flush(step, return_state=return_state) + return result + + @endpoint + async def init_backends( + self, + metadata_per_primary_backend: Dict[str, Dict[str, Any]], + config: Dict[str, Any], + ): + """Init local (per-rank) logger backends and MetricCollector.""" + collector = MetricCollector() + await collector.init_backends(metadata_per_primary_backend, config) + + @endpoint + async def shutdown(self): + + collector = MetricCollector() + await collector.shutdown() + + +class GlobalLoggingActor(Actor): + """Coordinates metric logging across all ranks for every training step. + + Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), + for per-rank and/or global reduction logging modes. + + If a backend config has flag `reduce_across_ranks=False`, an instance of the backend + is initialized per-rank, otherwise it is done once globally. + + This GlobalLoggingActor should be spawned once in the controller. 
A LocalFetcherActor + is automatically spawned per-rank in `forge.controller.provisioner.py` and registered + with this actor. The LocalFetcherActor is responsible for instantiating + the per-rank MetricCollector. + + In summary, the flow is: + - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector + - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + """ + + def __init__(self): + self.fetchers: Dict[str, LocalFetcherActor] = {} + self.config: Dict[str, Any] | None = None + self.global_logger_backends: Dict[str, "LoggerBackend"] = {} + self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} + + @endpoint + async def init_backends(self, config: Dict[str, Any]): + """ + Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + in all registered fetchers. + + A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source + for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. + + The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, + a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, + and will log independently, i.e. each rank will have its own run in wandb. + + Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states + and reduce them to a single value, which will be logged by the primary backend in this controller. + + Args: + config (Dict[str, Any]): Config for metric logging where keys are backend names, + e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} + """ + self.config = config + + for backend_name, backend_config in config.items(): + backend = get_logger_backend_class(backend_name)(backend_config) + await backend.init(role="global") + + # Extract metadata from primary logger to be shared with secondary loggers + # and store it + reduce_across_ranks = backend_config.get("reduce_across_ranks", True) + if not reduce_across_ranks: + primary_backend_metadata = ( + backend.get_metadata_for_secondary_ranks() or {} + ) + self.metadata_per_primary_backend[ + backend_name + ] = primary_backend_metadata + + # Store global logger backends + if reduce_across_ranks: + self.global_logger_backends[backend_name] = backend + + # Eager init collectors on all registered fetchers in parallel, passing primary states and config + if self.fetchers: + tasks = [ + fetcher.init_backends.call( + self.metadata_per_primary_backend, self.config + ) + for fetcher in self.fetchers.values() + ] + await asyncio.gather(*tasks, return_exceptions=True) + + @endpoint + async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh): + """Registers a fetcher with the global actor. Each key represents a process mesh. + If there are 2 processes, each with 2 replicas with N gpus, we would + have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" + self.fetchers[name] = fetcher + + # Self-init for respawned actors + if self.config: + logger.debug(f"Initializing new LocalFetcherActor {name}") + await fetcher.init_backends.call( + self.metadata_per_primary_backend, self.config + ) + + @endpoint + async def deregister_fetcher(self, name: str | ProcMesh): + if name not in self.fetchers: + logger.warning( + f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." 
+ f"Available fetchers: {self.fetchers.keys()}" + ) + return + del self.fetchers[name] + + @endpoint + async def flush(self, step: int): + """ + Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors + log to local backends and return states if needed for cross-rank reduction. + + Args: + step (int): Global step for logging. + """ + if not self.fetchers: + return + + config = self.config + # if reduce_across_ranks=True, we need to reduce the states from all ranks + # and log with the primary backend + requires_reduce = any( + backend_config.get("reduce_across_ranks", True) + for backend_config in config.values() + ) + + logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + + # Broadcast flush to all fetchers + results = await asyncio.gather( + *[ + f.flush.call(step, return_state=requires_reduce) + for f in self.fetchers.values() + ], + return_exceptions=True, + ) + + if requires_reduce: + # Handle exceptions and extract values from ValueMesh results + all_local_states = [] + for result in results: + if isinstance(result, Exception): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" + ) + + if not all_local_states: + logger.warning(f"No states to reduce for step {step}") + return + + # Reduce + reduced_metrics = reduce_metrics_states(all_local_states) + + # Log to each global logger_backend + for ( + logger_backend_name, + logger_backend, + ) in self.global_logger_backends.items(): + await logger_backend.log(reduced_metrics, step) + + @endpoint + def has_fetcher(self, name: str | ProcMesh) -> bool: + """Check if a fetcher is registered with the given name.""" + return name in self.fetchers + + @endpoint + def get_fetcher_count(self) -> int: + return len(self.fetchers) + + @endpoint + async def shutdown(self): + # Finish per-rank logger_backends via fetchers + if self.fetchers: + tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] + await asyncio.gather(*tasks, return_exceptions=True) + # Finish global logger_backends + for logger_backend_name, logger_backend in self.global_logger_backends.items(): + await logger_backend.finish() diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py new file mode 100644 index 000000000..2f2b70494 --- /dev/null +++ b/src/forge/observability/metrics.py @@ -0,0 +1,690 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional + +from monarch.actor import context, current_rank + + +logger = logging.getLogger(__name__) + + +class ReductionType(Enum): + MEAN = "mean" + SUM = "sum" + MAX = "max" + MIN = "min" + STD = "std" + + @property + def accumulator_class(self): + mapping = { + ReductionType.MEAN: MeanAccumulator, + ReductionType.SUM: SumAccumulator, + ReductionType.MAX: MaxAccumulator, + ReductionType.MIN: MinAccumulator, + ReductionType.STD: StdAccumulator, + } + return mapping[self] + + +def get_actor_name_with_rank() -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. + """ + # Add more defensive checks + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + rank_name = "UnknownActor" # fallback + if len(parts) >= 2: + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Extract world ID and proc rank + world_id = world_part.split("[")[0] if "[" in world_part else world_part + + # Extract clean actor name (remove "Configured" suffix if present) + if "[" in actor_part: + actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" + if actor_name.endswith("Configured"): + actor_name = actor_name[:-10] # Remove "Configured" + else: + actor_name = actor_part + + # Use last 4 characters of world_id as replica identifier + # This is deterministic, readable, and works for any number of replicas + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + # Use current_rank().rank as the local rank within the replica + local_rank = rank.rank + + rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + + return rank_name + + +def record_metric( + key: str, value: Any, reduction: ReductionType = ReductionType.MEAN +) -> None: + """ + Records a metric value for later reduction and logging. + + Relies on a per-rank MetricCollector singleton for ease of use, i.e. + call `record_metric` anywhere in the code without moving the + collector from function to function. + + The collector methods are triggered per-rank by a + `forge.observability.metric_actors.LocalFetcherActor`, instantiated + during actor initialization. + + Records are flushed when `forge.observability.metric_actors.GlobalLoggingActor.flush()` + is called, typically triggered by the training loop at regular intervals. + """ + collector = MetricCollector() + collector.push(key, value, reduction) + + +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: + """Reduce metric accumulators states to a single value per metric. + + Can be used when reducing metrics across ranks or services, as merging + states is more precise than merging locally reduced metrics. + + Args: + states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics, + normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. 
+ + Returns: + Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + + Example: + states = [ + {"loss": {"count": 5, "sum": 14, "reduction_type": ReductionType.MEAN}}, + {"loss": {"count": 10, "sum": 16, "reduction_type": ReductionType.MEAN}}, + ] + reduce_metrics_states(states) + >>> {"loss": 2.0} + + Raises: + ValueError: on mismatched reduction types for the same metric key. + """ + if not states: + return {} + + # Collect unique keys across all + all_keys = set(k for state in states for k in state) + + reduced_metrics = {} + for key in all_keys: + metric_states = [state.get(key) for state in states if key in state] + if not metric_states: + continue + + first_reduction_type = metric_states[0]["reduction_type"] + + # Check consistency + for state in metric_states: + if state["reduction_type"] != first_reduction_type: + raise ValueError( + f"Mismatched reduction types for key '{key}': {first_reduction_type} vs {state['reduction_type']}" + ) + + metric_accumulator = ReductionType(first_reduction_type).accumulator_class + reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) + reduced_metrics[key] = reduced_value + + return reduced_metrics + + +################ +# Accumulators # +################ + + +class MetricAccumulator(ABC): + """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" + + def __init__(self, reduction: ReductionType): + self.reduction_type = reduction + + @abstractmethod + def append(self, value: Any) -> None: + """Updates accumulator with new value (e.g., adds to sum and count for MEAN).""" + pass + + @abstractmethod + def get_value(self) -> Any: + """Returns locally reduced value (e.g., sum/count for MEAN).""" + pass + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + """Returns serializable state for cross-rank merge (e.g., {'sum': 10.0, 'count': 5}).""" + pass + + @classmethod + @abstractmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> Any: + """Merges states from multiple ranks into single reduced value (e.g., total_sum/total_count for MEAN).""" + pass + + @abstractmethod + def reset(self) -> None: + """Clears for next accumulation cycle (e.g., sum=0, count=0 for MEAN).""" + pass + + +class MeanAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.count += 1 + + def get_value(self) -> float: + return self.sum / self.count if self.count > 0 else 0.0 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_count = sum(s["count"] for s in states) + return total_sum / total_count if total_count > 0 else 0.0 + + def reset(self) -> None: + self.sum = 0.0 + self.count = 0 + + +class SumAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.total = 0.0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.total += v + + def get_value(self) -> float: + return self.total + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "total": 
self.total} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return sum(s["total"] for s in states) + + def reset(self) -> None: + self.total = 0.0 + + +class MaxAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.max_val = float("-inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.max_val = max(self.max_val, v) + + def get_value(self) -> float: + return self.max_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return max(s["max_val"] for s in states) + + def reset(self) -> None: + self.max_val = float("-inf") + + +class MinAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.min_val = float("inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.min_val = min(self.min_val, v) + + def get_value(self) -> float: + return self.min_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return min(s["min_val"] for s in states) + + def reset(self) -> None: + self.min_val = float("inf") + + +class StdAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.sum_sq = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.sum_sq += v * v + self.count += 1 + + def get_value(self) -> float: + if self.count == 0: + return 0.0 + if self.count == 1: + return 0.0 + mean = self.sum / self.count + variance = (self.sum_sq / self.count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "sum_sq": self.sum_sq, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_sum_sq = sum(s["sum_sq"] for s in states) + total_count = sum(s["count"] for s in states) + if total_count == 0: + return 0.0 + if total_count == 1: + return 0.0 + mean = total_sum / total_count + variance = (total_sum_sq / total_count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def reset(self) -> None: + self.sum = 0.0 + self.sum_sq = 0.0 + self.count = 0 + + +############# +# Collector # +############# + + +class MetricCollector: + """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. + + A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, + in the GlobalLoggingActor. + + - Ensures one instance per process; actors call record_metric() which delegates here. + - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; + - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also + return non-reduced states for global aggregation. 
This can be different for each backend. + - Resets accumulators post-flush to avoid leaks across train steps; + """ + + _instances: Dict[int, "MetricCollector"] = {} + + def __new__(cls): + """Singleton per-rank, ensures one instance per process.""" + rank = current_rank().rank + + if rank not in cls._instances: + inst = super().__new__(cls) + cls._instances[rank] = inst + inst._singleton_rank = rank + else: + inst = cls._instances[rank] + if inst._singleton_rank != rank: + raise ValueError( + f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" + ) + return inst + + def __init__(self): + if hasattr(self, "_is_initialized"): + return + + self.accumulators: Dict[str, MetricAccumulator] = {} + self.rank = current_rank().rank + self.logger_backends: List[LoggerBackend] = [] + self._is_initialized = False + + async def init_backends( + self, + metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], + config: Dict[str, Any], + ) -> None: + """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated + once globally. + + Args: + metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary + logger backend, e.g., {"wandb": {"run_id": "abc123"}}. + config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + """ + if self._is_initialized: + logger.debug(f"Rank {self.rank}: MetricCollector already initialized") + return + + # instantiate local backends if any + for backend_name, backend_config in config.items(): + if backend_config.get("reduce_across_ranks", True): + continue # Skip local backend instantiation and use global instead + primary_state = metadata_per_primary_backend.get(backend_name, {}) + logger_backend = get_logger_backend_class(backend_name)(backend_config) + await logger_backend.init( + role="local", primary_logger_metadata=primary_state + ) + self.logger_backends.append(logger_backend) + + self._is_initialized = True + + def push( + self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN + ) -> None: + if not self._is_initialized: + raise ValueError("Collector not initialized—call init first") + + if key not in self.accumulators: + self.accumulators[key] = reduction.accumulator_class(reduction) + + self.accumulators[key].append(value) + + async def flush( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): Step used by backends to align metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. + """ + if not self._is_initialized: + logger.debug( + f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." 
+ ) + return {} + + if not self.accumulators: + logger.debug( + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + ) + return {} + + # Snapshot states and reset immediately + states = {} + for key, acc in self.accumulators.items(): + states[key] = acc.get_state() + acc.reset() + + # Reduce metrics from states for logging if any per-rank backend + if self.logger_backends: + metrics = {} + for key, state in states.items(): + acc_class = ReductionType(state["reduction_type"]).accumulator_class + metrics[key] = acc_class.get_reduced_value_from_states([state]) + + # Log to local logger_backends + for logger_backend in self.logger_backends: + await logger_backend.log(metrics, step) + + return states if return_state else {} + + async def shutdown(self): + """Shutdown logger_backends if initialized.""" + if not self._is_initialized: + logger.debug( + f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + ) + return + + for logger_backend in self.logger_backends: + await logger_backend.finish() + + +########### +# Backends # +########### + + +class LoggerBackend(ABC): + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" + + def __init__(self, logger_backend_config: Dict[str, Any]): + self.logger_backend_config = logger_backend_config + + @abstractmethod + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initializes backend, e.g. wandb.run.init(). + + Args: + role (str): "global" (controller/primary) or "local" (per-rank/secondary). + Can be used to behave differently for primary vs secondary roles. + primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for + backend that required shared info, e.g. {"shared_run_id": "abc123"}. + + Raises: ValueError if missing metadata for shared local init. + """ + if primary_logger_metadata is None: + primary_logger_metadata = {} + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + async def finish(self) -> None: + pass + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + return None + + +class ConsoleBackend(LoggerBackend): + """Simple console logging of metrics.""" + + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + self.prefix = ( + get_actor_name_with_rank() + if self.logger_backend_config.get("reduce_across_ranks", True) + else "GLOBAL" + ) + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") + for key, value in metrics.items(): + logger.info(f" {key}: {value}") + logger.info("==============================\n") + + async def finish(self) -> None: + pass + + +class WandbBackend(LoggerBackend): + """ + Weights & Biases logging backend for distributed training. 
+ + Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: + Track a single process: reduce_across_ranks=True + Track each process separately: reduce_across_ranks=False, share_run_id=False + Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + + Configuration: + reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). + If False, enables per-rank logging; then use share_run_id to pick mode. + share_run_id (bool, default False): Only used if reduce_across_ranks=False. + True -> shared run across ranks; False -> separate runs per rank. + project (str): WandB project name + group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" + """ + + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + self.project = logger_backend_config["project"] + self.group = logger_backend_config.get("group", "experiment_group") + self.name = None + self.run = None + self.reduce_across_ranks = logger_backend_config.get( + "reduce_across_ranks", True + ) + self.share_run_id = logger_backend_config.get("share_run_id", False) + + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + + if primary_logger_metadata is None: + primary_logger_metadata = {} + + if role not in ["global", "local"]: + raise ValueError( + f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." + ) + + self.name = ( + get_actor_name_with_rank() if role == "local" else "global_controller" + ) + + # Default global mode: only inits on controller + if self.reduce_across_ranks: + if role != "global": + logger.debug( + f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
+ ) + return + await self._init_global() + + # Per-rank modes based on share_run_id bool + elif role == "global" and self.share_run_id: + await self._init_shared_global() + + elif role == "local": + if self.share_run_id: + await self._init_shared_local(primary_logger_metadata) + else: + await self._init_per_rank() + + async def _init_global(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group) + + async def _init_per_rank(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group, name=self.name) + + async def _init_shared_global(self): + import wandb + + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init(project=self.project, group=self.group, settings=settings) + + async def _init_shared_local(self, primary_metadata: Dict[str, Any]): + import wandb + + shared_id = primary_metadata.get("shared_run_id") + if shared_id is None: + raise ValueError( + f"Shared ID required but not provided for {self.name} backend init" + ) + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.run: + log_data = {**metrics, "global_step": step} + self.run.log(log_data) + logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + else: + logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") + + def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]: + if self.run and not self.reduce_across_ranks and self.share_run_id: + return {"shared_run_id": self.run.id} + return {} + + async def finish(self) -> None: + if self.run: + self.run.finish() + logger.info(f"WandbBackend {self.name}: Finished run") + + +def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: + """Simple mapping between logger_backend type and its class + + Factory for backend classes from config; returns uninitialized class for role-based init. + """ + if cls_name == "console": + return ConsoleBackend + elif cls_name == "wandb": + return WandbBackend + else: + raise ValueError(f"Unknown logger backend type: {cls_name}") diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py new file mode 100644 index 000000000..cf16a8cf0 --- /dev/null +++ b/tests/integration_tests/conftest.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse + +import pytest + + +def str_to_bool(value): + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + elif value.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'") + + +def pytest_addoption(parser): + """Add custom command line options for pytest.""" + parser.addoption( + "--config", + action="store", + default=None, + help="Path to YAML config file for sanity check tests", + ) + + parser.addoption( + "--use_dcp", + action="store", + type=str_to_bool, + default=None, + help="Overrides the YAML config `trainer.use_dcp` field.", + ) + + +@pytest.fixture +def config_path(request): + """Fixture to provide the config path from command line.""" + return request.config.getoption("--config") diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml new file mode 100644 index 000000000..4d3a56d04 --- /dev/null +++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml @@ -0,0 +1,72 @@ +# Global configuration +group_size: 8 +batch_size: 16 +max_req_tokens: 512 +max_res_tokens: 512 +model: "Qwen/Qwen3-1.7B" +off_by_n: 1 # Off by one by default + + +# Policy configuration +policy: + engine_config: + model: ${model} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_config: + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${batch_size} + seq_len: 2048 + max_norm: 1.0 + steps: 1000000 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# All resource allocations +services: + policy: + procs: ${policy.engine_config.tensor_parallel_size} + num_replicas: 1 + with_gpus: true + trainer: + procs: 1 + num_replicas: 1 + with_gpus: true diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml new file mode 100644 index 000000000..0ac915d2a --- /dev/null +++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml @@ -0,0 +1,74 @@ +# trainer tp = 2, policy tp = 4 + +# Global confOiguration +group_size: 8 +batch_size: 16 +max_req_tokens: 512 +max_res_tokens: 512 +model: "Qwen/Qwen3-1.7B" +off_by_n: 1 # Off by one by default + + +# Policy configuration +policy: + engine_config: + model: ${model} + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + enforce_eager: false + sampling_config: + n: ${group_size} + max_tokens: ${max_res_tokens} + temperature: 1.0 + top_p: 1.0 + +# Trainer configuration +trainer: + model: + name: qwen3 + flavor: 1.7B + hf_assets_path: hf://${model} + optimizer: + name: AdamW + lr: 1e-5 + eps: 1e-8 + lr_scheduler: + warmup_steps: 1 + training: + local_batch_size: ${batch_size} + seq_len: 2048 + max_norm: 1.0 + steps: 1000000 + dtype: bfloat16 + gc_freq: 1 + compile: + enable: false + parallelism: + 
data_parallel_replicate_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 2 + pipeline_parallel_degree: 1 + context_parallel_degree: 1 + expert_parallel_degree: 1 + disable_loss_parallel: true + checkpoint: + enable: true + initial_load_path: hf://${model} + initial_load_in_hf: true + last_save_in_hf: true + interval: 500 + async_mode: "disabled" + activation_checkpoint: + mode: selective + selective_ac_option: op + +# All resource allocations +services: + policy: + procs: ${policy.engine_config.tensor_parallel_size} + num_replicas: 1 + with_gpus: true + trainer: + procs: 2 + num_replicas: 1 + with_gpus: true diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index bd0c7e17b..0968ef729 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -6,20 +6,21 @@ import asyncio import logging -from dataclasses import asdict from tempfile import TemporaryDirectory import pytest + import torch import torchstore as ts -from forge.actors.policy import EngineConfig, Policy, SamplingConfig +from forge.actors.policy import Policy from forge.actors.trainer import RLTrainer -from forge.controller.service import ServiceConfig +from forge.cli.config import resolve_hf_hub_paths from forge.controller.service.service import uuid from monarch.actor import endpoint +from omegaconf import DictConfig, OmegaConf requires_cuda = pytest.mark.skipif( not torch.cuda.is_available(), @@ -31,29 +32,15 @@ logger: logging.Logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +""" +Run tests: -# Run tests: pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync:: +pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \ + --config tests/integration_tests/artifacts/qwen3_1_7b_tp.yaml --use_dcp=false - -def get_configs( - worker_size: int, tp_size: int, model_name: str -) -> tuple[dict, ServiceConfig]: - engine_config = EngineConfig( - model=model_name, - tensor_parallel_size=tp_size, - pipeline_parallel_size=1, - enforce_eager=True, - ) - sampling_config = SamplingConfig( - n=3, - guided_decoding=True, - ) - policy_config = { - "engine_config": engine_config, - "sampling_config": sampling_config, - } - service_config = ServiceConfig(procs=worker_size, num_replicas=1, with_gpus=True) - return policy_config, service_config +pytest -s tests/integration_tests/test_policy_update.py::TestWeightSync::test_sanity_check \ + --config apps/grpo/qwen3_8b.yaml +""" class MockRLTrainer(RLTrainer): @@ -141,163 +128,92 @@ def validate_fn_all_zeros(prev_params, curr_model, logger) -> Exception | None: class TestWeightSync: - """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B.""" - - model = "Qwen/Qwen3-1.7B" - - def default_trainer_cfg(self): - cached_dir = snapshot_download(repo_id=self.model) - return { - "model": { - "name": "qwen3", - "flavor": "1.7B", - }, - "checkpoint": { - "enable": True, - "folder": "/tmp/saved_checkpoints", - "initial_load_path": cached_dir, - "initial_load_in_hf": True, - }, - } + """Tests for weight sync between trainer and policy.""" - def default_trainer_cfg_tp(self): - # NB: TP size is set to 2. 
- cached_dir = snapshot_download(repo_id=self.model) - return { - "model": { - "name": "qwen3", - "flavor": "1.7B", - }, - "parallelism": {"tensor_parallel_degree": 2}, - "checkpoint": { - "enable": True, - "folder": "/tmp/saved_checkpoints", - "initial_load_path": cached_dir, - "initial_load_in_hf": True, - }, - } + def _load_config(self, config_path: str) -> DictConfig: + cfg = None + try: + cfg = OmegaConf.load(config_path) + except Exception as e: + pytest.fail(f"Failed to load config file {config_path}: {e}") + + assert isinstance(cfg, DictConfig) + + cfg = resolve_hf_hub_paths(cfg) + return cfg @pytest.mark.asyncio @requires_cuda - @pytest.mark.parametrize( - "use_dcp", [pytest.param(True, id="use_dcp"), pytest.param(False, id="no_dcp")] - ) - async def test_policy_update_single(self, use_dcp): + async def test_sanity_check(self, request): """ - Test the weight synchronization process between RLTrainer and Policy. + Sanity check for weight sync sharding between RLTrainer and Policy for a given model config. - This test performs the following steps: + The check performs the following steps: - Initialize trainer and push weights v0 (original huggingface ckpt) - Step the trainer, setting all weights to zero and push weights v1 - Load weights v0 and check the policy has all zero weights - Load weights v1 and check the policy has all the weights back - """ - trainer_worker_size = 1 - policy_worker_size = 1 - tp_size = 1 - - await ts.initialize() - - trainer_cfg = self.default_trainer_cfg() - trainer_cfg["use_dcp"] = use_dcp - with TemporaryDirectory(dir="/dev/shm/") as tmpdir: - if use_dcp: - trainer_cfg["dcp_path"] = tmpdir - policy_config, service_config = get_configs( - worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model - ) - policy, rl_trainer = await asyncio.gather( - *[ - Policy.options(**asdict(service_config)).as_service( - **policy_config - ), - MockRLTrainer.options( - procs=trainer_worker_size, with_gpus=True, num_replicas=1 - ).as_service(**trainer_cfg), - ] + """ + # Test setup + config_path = request.config.getoption("--config", default=None) + if not config_path: + pytest.skip( + "No config file provided. 
+                "Use --config to specify a YAML config file"
            )
-        v0 = uuid.uuid4().int
-        v1 = v0 + 1
+        use_dcp_override = request.config.getoption("--use_dcp")
+        cfg = self._load_config(config_path=config_path)
 
-        await rl_trainer.push_weights.fanout(policy_version=v0)
-        # Setting everything to zero
-        await rl_trainer.zero_out_model_states.fanout()
-        await rl_trainer.push_weights.fanout(policy_version=v1)
-        await policy._test_save_model_params.fanout()
+        trainer_proc_size = cfg.services.trainer.procs
+        policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
 
-        # Sanity check that before update all the tests pass
-        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-        for errs in all_errs:
-            for _, e in errs.items():
-                assert not e, f"Validation failed with exception: {e}"
-
-        await policy.update_weights.fanout(policy_version=v1)
-        all_errs = await policy._test_validate_model_params.fanout(
-            validate_fn_all_zeros
+        if policy_tp_size != cfg.services.policy.procs:
+            pytest.fail(
+                f"Expected policy procs ({cfg.services.policy.procs}) to equal the policy tensor parallel size ({policy_tp_size})"
             )
-        for errs in all_errs:
-            for _, e in errs.items():
-                assert not e, f"Validation failed with exception: {e}"
 
-        # Reloading v0, getting back original weights
-        await policy.update_weights.fanout(policy_version=v0)
-        all_errs = await policy._test_validate_model_params.fanout(validate_fn)
-        for errs in all_errs:
-            for _, e in errs.items():
-                assert not e, f"Validation failed with exception: {e}"
+        model_card = cfg.model
 
-        await ts.shutdown()
+        logger.info(f"Running sanity check with config: {config_path}")
+        logger.info(f"Model name: {model_card}")
+        logger.info(f"Trainer proc size: {trainer_proc_size}")
+        logger.info(f"Policy tensor parallel size: {policy_tp_size}")
 
-    @pytest.mark.asyncio
-    @requires_cuda
-    @pytest.mark.parametrize(
-        "use_dcp", [pytest.param(True, id="use_dcp"), pytest.param(False, id="no_dcp")]
-    )
-    async def test_policy_update_tp(self, use_dcp):
-        """
-        Test the weight synchronization process between RLTrainer and Policy.
-
-
-        This test performs the following steps:
-        - Initialize trainer and push weights v0 (original huggingface ckpt)
-        - Step the trainer, setting all weights to zero and push weights v1
-        - Load weights v0 and check the policy has all zero weights
-        - Load weights v1 and check the policy has all the weights back
-        """
-        # test configs/paralleism
-        trainer_worker_size = 2
-        policy_worker_size = 2
-        tp_size = 2
-
-        if torch.cuda.device_count() < 2:
-            pytest.skip(
-                f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
-            )
+        logger.info("Downloading model checkpoint from HuggingFace Hub")
+        cached_dir = snapshot_download(repo_id=model_card)
+        logger.info("Finished downloading model checkpoint from HuggingFace Hub")
 
         await ts.initialize()
-
-        trainer_cfg = self.default_trainer_cfg_tp()
-        trainer_cfg["use_dcp"] = use_dcp
+        services_policy_cfg = cfg.services.policy
+        services_policy_cfg.num_replicas = 1
+
+        services_trainer_cfg = cfg.services.trainer
+        services_trainer_cfg.num_replicas = 1
+
+        trainer_cfg = cfg.trainer
+        trainer_cfg.checkpoint = {
+            "enable": True,
+            "folder": "/tmp/saved_checkpoints",
+            "initial_load_path": cached_dir,
+            "initial_load_in_hf": True,
+        }
+        if use_dcp_override is not None:
+            trainer_cfg["use_dcp"] = use_dcp_override
+            logger.info(f"`trainer.use_dcp` is overridden to {use_dcp_override}")
 
         with TemporaryDirectory(dir="/dev/shm/") as tmpdir:
-            if use_dcp:
-                trainer_cfg["dcp_path"] = tmpdir
-
-            policy_config, service_config = get_configs(
-                worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
-            )
+            trainer_cfg["dcp_path"] = tmpdir
             policy, rl_trainer = await asyncio.gather(
                 *[
-                    Policy.options(**asdict(service_config)).as_service(
-                        **policy_config
+                    Policy.options(**services_policy_cfg).as_service(**cfg.policy),
+                    MockRLTrainer.options(**services_trainer_cfg).as_service(
+                        **trainer_cfg
                     ),
-                    MockRLTrainer.options(
-                        procs=trainer_worker_size, with_gpus=True, num_replicas=1
-                    ).as_service(**trainer_cfg),
                 ]
             )
 
+            # Main logic begins here
             v0 = uuid.uuid4().int
             v1 = v0 + 1
 
@@ -320,6 +236,7 @@ async def test_policy_update_tp(self, use_dcp):
                 for errs in all_errs:
                     for _, e in errs.items():
                         assert not e, f"Validation failed with exception: {e}"
+
             # Reloading v0, getting back original weights
             await policy.update_weights.fanout(policy_version=v0)
             all_errs = await policy._test_validate_model_params.fanout(validate_fn)
@@ -327,4 +244,5 @@ async def test_policy_update_tp(self, use_dcp):
                 for _, e in errs.items():
                     assert not e, f"Validation failed with exception: {e}"
 
+            logger.info("✅ Weight sharding sanity check passed!")
             await ts.shutdown()
diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py
new file mode 100644
index 000000000..6a2e23fbf
--- /dev/null
+++ b/tests/unit_tests/test_torchstore_utils.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import tempfile
+import unittest
+
+from pathlib import Path
+
+import pytest
+
+import torch
+import torch.distributed.checkpoint as dcp
+from forge.actors._torchstore_utils import DcpHandle
+
+ignore_torch_distributed_uninitialized_warning = pytest.mark.filterwarnings(
+    r"ignore:.*torch.distributed"
+)
+
+
+class TestDcpHandle(unittest.TestCase):
+    def _prepare_dcp_handle(self, test_dir: str) -> tuple[str, DcpHandle]:
+        """Returns path to checkpoint and DcpHandle."""
+        checkpoint_id = str(Path(test_dir) / "test_checkpoint_id")
+        state_dict = {"a": torch.rand(1, 1), "b": torch.rand(1, 1)}
+        metadata = dcp.save(checkpoint_id=checkpoint_id, state_dict=state_dict)
+        assert os.path.exists(checkpoint_id), "failed to set up test checkpoint"
+        return checkpoint_id, DcpHandle(
+            checkpoint_id=checkpoint_id,
+            metadata=metadata,
+            param_names=list(state_dict.keys()),
+        )
+
+    @ignore_torch_distributed_uninitialized_warning
+    def test_dcp_handle_drop_deletes(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            ckpt_path, handle = self._prepare_dcp_handle(test_dir)
+            handle.drop()
+            self.assertFalse(os.path.exists(ckpt_path))
+
+    @ignore_torch_distributed_uninitialized_warning
+    def test_dcp_handle_drop_sets_none(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            _, handle = self._prepare_dcp_handle(test_dir)
+            handle.drop()
+            self.assertEqual(handle.checkpoint_id, None)
+            self.assertEqual(handle.metadata, None)
+            self.assertEqual(handle.param_names, None)
+
+    @ignore_torch_distributed_uninitialized_warning
+    def test_dcp_handle_drop_sets_none_for_manifold(self):
+        with tempfile.TemporaryDirectory() as test_dir:
+            _, handle = self._prepare_dcp_handle(test_dir)
+            handle.checkpoint_id = "manifold://test_bucket/tree/test_path"
+            handle.drop()
+            self.assertEqual(handle.checkpoint_id, None)
+            self.assertEqual(handle.metadata, None)
+            self.assertEqual(handle.param_names, None)
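
Reviewer note: the new unit tests above pin down the contract of DcpHandle.drop(), namely that dropping a handle deletes a local DCP checkpoint directory from disk and then clears the handle's fields, while for remote URIs such as manifold:// only the in-memory handle state is cleared. The snippet below is a minimal illustrative sketch of that contract, not the shipped implementation; the real DcpHandle lives in src/forge/actors/_torchstore_utils.py, and the name SketchDcpHandle plus the "://" check are assumptions made purely for illustration.

# Illustrative sketch only: mirrors the behavior asserted by TestDcpHandle,
# not the actual DcpHandle implementation in src/forge/actors/_torchstore_utils.py.
import shutil
from dataclasses import dataclass
from typing import Optional

from torch.distributed.checkpoint.metadata import Metadata as DcpMeta


@dataclass
class SketchDcpHandle:
    checkpoint_id: Optional[str] = None
    metadata: Optional[DcpMeta] = None
    param_names: Optional[list[str]] = None

    def drop(self) -> None:
        # Local checkpoint directories are removed from disk; remote URIs
        # (e.g. manifold://...) are assumed to be cleaned up elsewhere, so
        # only the in-memory handle state is cleared for them.
        if self.checkpoint_id and "://" not in self.checkpoint_id:
            shutil.rmtree(self.checkpoint_id, ignore_errors=True)
        self.checkpoint_id = None
        self.metadata = None
        self.param_names = None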