From bb0386ac7db668a11fdabf2882fbad357507f519 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 12:21:37 -0700 Subject: [PATCH 01/10] wandb metric logging --- apps/sft/llama3_8b.yaml | 8 +- apps/sft/main.py | 4 +- apps/sft_v2/llama3_8b.yaml | 8 +- apps/sft_v2/main.py | 5 +- src/forge/types.py | 8 +- src/forge/util/__init__.py | 15 +++ src/forge/util/logging.py | 147 ++++++++++++++++++++++ src/forge/util/metric_logging.py | 208 +++++++++++++++++++++++++++++++ 8 files changed, 392 insertions(+), 11 deletions(-) create mode 100644 src/forge/util/__init__.py create mode 100644 src/forge/util/logging.py create mode 100644 src/forge/util/metric_logging.py diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 573b401bd..8006b9534 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -2,10 +2,10 @@ # profiling: # enable_profiling: false -# metrics: -# log_freq: 10 -# enable_tensorboard: true -# save_tb_folder: "tb" +metrics: + logger: 'wandb' + log_dir: "metrics_log" + log_freq: 10 # TODO: required by torchtitan # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 diff --git a/apps/sft/main.py b/apps/sft/main.py index 9781dad5c..988095931 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -18,6 +18,7 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.util import get_metric_logger from omegaconf import DictConfig, OmegaConf from torch import nn @@ -60,7 +61,7 @@ def __init__(self, job_config: ForgeJobConfig): self.num_training_steps = job_config.training.steps self.gradient_accumulation_steps = 1 # Example value, adjust as needed super().__init__(job_config) - self.metric_logger = None # TODO: fix this + self.metric_logger = get_metric_logger(**job_config.metrics) def setup(self): self.train_dataloader = self.setup_data() @@ -185,6 +186,7 @@ def train_step(self, batch) -> None: loss = self.forward_backward(batch, labels) self.pbar.update(1) self.pbar.set_description(f"{self.current_step}|Loss: {loss}") + self.metric_logger.log("loss", loss, self.current_step) self.optimizers.step() self.lr_schedulers.step() diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index 1bb180faa..4f940c3e5 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -2,10 +2,10 @@ # profiling: # enable_profiling: false -# metrics: -# log_freq: 10 -# enable_tensorboard: true -# save_tb_folder: "tb" +metrics: + logger: 'wandb' + log_dir: "metrics_log" + log_freq: 10 # TODO: required by torchtitan # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 diff --git a/apps/sft_v2/main.py b/apps/sft_v2/main.py index 1c1f43c01..0426d0be3 100644 --- a/apps/sft_v2/main.py +++ b/apps/sft_v2/main.py @@ -29,6 +29,7 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.util import get_metric_logger from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -74,7 +75,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine): def __init__(self, job_config: ForgeJobConfig): self.current_step = 0 self.num_training_steps = 
job_config.training.steps - self.metric_logger = None # TODO: fix this + self.metric_logger = get_metric_logger(**job_config.metrics) self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) @@ -238,6 +239,8 @@ def train_step(self, batch) -> None: logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) + self.metric_logger.log("loss", loss, self.current_step) + self.optimizers.step() self.lr_schedulers.step() diff --git a/src/forge/types.py b/src/forge/types.py index 08eb85976..ac0938d9b 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from typing import Any, Literal, TypedDict +from typing import Any, Literal, TypedDict, Union + +from numpy import ndarray +from torch import Tensor class Message(TypedDict): @@ -98,3 +101,6 @@ class ProcessConfig: oncall: str = "torchtune" identity: str = "pytorch_distributed" image: str = "forge_workspace:latest" + + +Scalar = Union[Tensor, ndarray, int, float] diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py new file mode 100644 index 000000000..c9cf0bfd3 --- /dev/null +++ b/src/forge/util/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from .logging import deprecated, get_logger, log_once, log_rank_zero +from .metric_logging import get_metric_logger + +__all__ = [ + "deprecated", + "get_logger", + "log_once", + "log_rank_zero", + "get_metric_logger", +] diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py new file mode 100644 index 000000000..fcfb047a0 --- /dev/null +++ b/src/forge/util/logging.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import warnings +from functools import lru_cache, wraps +from typing import Callable, Optional, TypeVar + +from torch import distributed as dist + +T = TypeVar("T", bound=type) + + +def get_logger(level: Optional[str] = None) -> logging.Logger: + """ + Get a logger with a stream handler. + + Args: + level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels. + + Example: + >>> logger = get_logger("INFO") + >>> logger.info("Hello world!") + INFO:torchtune.utils._logging:Hello world! + + Returns: + logging.Logger: The logger. + """ + logger = logging.getLogger(__name__) + if not logger.hasHandlers(): + logger.addHandler(logging.StreamHandler()) + if level is not None: + level = getattr(logging, level.upper()) + logger.setLevel(level) + return logger + + +def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: + """ + Logs a message only on rank zero. + + Args: + logger (logging.Logger): The logger. + msg (str): The warning message. + level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. + Defaults to ``logging.INFO``. 
+ """ + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + if rank != 0: + return + logger.log(level, msg, stacklevel=2) + + +@lru_cache(None) +def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: + """ + Logs a message only once. LRU cache is used to ensure a specific message is + logged only once, similar to how :func:`~warnings.warn` works when the ``once`` + rule is set via command-line or environment variable. + + Args: + logger (logging.Logger): The logger. + msg (str): The warning message. + level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. + Defaults to ``logging.INFO``. + """ + log_rank_zero(logger=logger, msg=msg, level=level) + + +def deprecated(msg: str = "") -> Callable[[T], T]: + """ + Decorator to mark an object as deprecated and print additional message. + + Args: + msg (str): additional information to print after warning. + + Returns: + Callable[[T], T]: the decorated object. + """ + + @lru_cache(maxsize=1) + def warn(obj): + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + if rank != 0: + return + warnings.warn( + f"{obj.__name__} is deprecated and will be removed in future versions. " + + msg, + category=FutureWarning, + stacklevel=3, + ) + + def decorator(obj): + @wraps(obj) + def wrapper(*args, **kwargs): + warn(obj) + return obj(*args, **kwargs) + + return wrapper + + return decorator + + +def deprecate_parameter(param_name: str, msg: str = "") -> Callable[[T], T]: + """ + Decorator to mark a parameter as deprecated and print additional message. + + Args: + param_name (str): The name of the parameter. + msg (str): additional information to print after warning. + + Returns: + Callable[[T], T]: the decorated object. + """ + + @lru_cache(maxsize=1) + def warn(obj): + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + if rank != 0: + return + warnings.warn( + f"{param_name} is deprecated for {obj.__name__} and will be removed in future versions. " + + msg, + category=FutureWarning, + stacklevel=3, + ) + + def decorator(obj): + sig = inspect.signature(obj) + + @wraps(obj) + def wrapper(*args, **kwargs): + # Check positional and kwargs + bound_args = sig.bind_partial(*args, **kwargs) + all_args = {**bound_args.arguments} + all_args.update(kwargs) + if param_name in all_args: + warn(obj) + return obj(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py new file mode 100644 index 000000000..13c59a776 --- /dev/null +++ b/src/forge/util/metric_logging.py @@ -0,0 +1,208 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +from abc import ABC, abstractmethod +from typing import Mapping, Optional + +import torch + +from forge.types import Scalar + + +def get_metric_logger(logger: str = "stdout", **log_config): + return METRIC_LOGGER_STR_TO_CLS[logger](**log_config) + + +class MetricLogger(ABC): + """Abstract metric logger. 
+ + Args: + log_freq (int): calls to `log` and `log_dict` will be ignored if `step % log_freq != 0` + """ + + def __init__(self, log_freq: int): + self._log_freq = log_freq + self._step = None + + def set_step(self, step: int) -> None: + """Subsequent log calls will use this step number by default if not provided to the log call.""" + self._step = step + + def is_log_step(self, step: Optional[int] = None): + """Returns true if the current step is a logging step. + + Args: + step (int): current step. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + return step % self._log_freq == 0 + + def log( + self, + name: str, + data: Scalar, + step: Optional[int] = None, + ) -> None: + """Log scalar data if this is a logging step. + + Args: + name (str): tag name used to group scalars + data (Scalar): scalar data to log + step (int): step value to record. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + if step % self._log_freq == 0: + self._log(name, data, step) + + def log_dict( + self, payload: Mapping[str, Scalar], step: Optional[int] = None + ) -> None: + """Log multiple scalar values if this is a logging step. + + Args: + payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + step (int): step value to record. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + if step % self._log_freq == 0: + self._log_dict(payload, step) + + @abstractmethod + def _log(self, name: str, data: Scalar, step: int) -> None: + """Log scalar data. + + Args: + name (str): tag name used to group scalars + data (Scalar): scalar data to log + step (int): step value to record + """ + pass + + @abstractmethod + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + """Log multiple scalar values. + + Args: + payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + step (int): step value to record + """ + pass + + def close(self) -> None: + """ + Close log resource, flushing if necessary. + Logs should not be written after `close` is called. + """ + pass + + +class WandBLogger(MetricLogger): + """Logger for use w/ Weights and Biases application (https://wandb.ai/). + For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init. + + Args: + log_dir (Optional[str]): WandB log directory. + project (str): WandB project name. Default is `torchtune`. + entity (Optional[str]): WandB entity name. If you don't specify an entity, + the run will be sent to your default entity, which is usually your username. + group (Optional[str]): WandB group name for grouping runs together. If you don't + specify a group, the run will be logged as an individual experiment. + **kwargs: additional arguments to pass to wandb.init + + Example: + >>> from torchtune.training.metric_logging import WandBLogger + >>> logger = WandBLogger(log_dir="wandb", project="my_project", entity="my_entity", group="my_group") + >>> logger.log("my_metric", 1.0, 1) + >>> logger.log_dict({"my_metric": 1.0}, 1) + >>> logger.close() + + Raises: + ImportError: If ``wandb`` package is not installed. 
+ + Note: + This logger requires the wandb package to be installed. + You can install it with `pip install wandb`. + In order to use the logger, you need to login to your WandB account. + You can do this by running `wandb login` in your terminal. + """ + + def __init__( + self, + log_dir: str, + project: str = "torchforge", + entity: Optional[str] = None, + group: Optional[str] = None, + **kwargs, + ): + try: + import wandb + except ImportError as e: + raise ImportError( + "``wandb`` package not found. Please install wandb using `pip install wandb` to use WandBLogger." + ) from e + self._wandb = wandb + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 0 + ) + if self._wandb.run is None and rank == 0: + # we check if wandb.init got called externally + run = self._wandb.init( + project=project, + entity=entity, + group=group, + dir=log_dir, + **kwargs, + ) + + if self._wandb.run: + self._wandb.run._label(repo="torchtune") + + # define default x-axis (for latest wandb versions) + if getattr(self._wandb, "define_metric", None): + self._wandb.define_metric("step") + self._wandb.define_metric("*", step_metric="step", step_sync=True) + + self.config_allow_val_change = kwargs.get("allow_val_change", False) + + def _log(self, name: str, data: Scalar, step: int) -> None: + if self._wandb.run: + self._wandb.log({name: data, "step": step}) + + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + if self._wandb.run: + self._wandb.log({**payload, "step": step}) + + def __del__(self) -> None: + # extra check for when there is an import error + if hasattr(self, "_wandb") and self._wandb.run: + self._wandb.finish() + + def close(self) -> None: + if self._wandb.run: + self._wandb.finish() + + +METRIC_LOGGER_STR_TO_CLS = { + "wandb": WandBLogger, +} From 1c7483d4d6a9dc04f5e2a5de90ca36cc4d233c03 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 14:28:12 -0700 Subject: [PATCH 02/10] tensorboard, stdout, disk loggers --- apps/sft/llama3_8b.yaml | 5 +- apps/sft_v2/llama3_8b.yaml | 5 +- src/forge/util/metric_logging.py | 234 ++++++++++++++++++++++++++++--- 3 files changed, 219 insertions(+), 25 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 8006b9534..967885651 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -3,9 +3,10 @@ # enable_profiling: false metrics: - logger: 'wandb' + logger: "stdout" log_dir: "metrics_log" - log_freq: 10 + log_freq: + loss: 10 # TODO: required by torchtitan # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index 4f940c3e5..ad5892b3b 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -3,9 +3,10 @@ # enable_profiling: false metrics: - logger: 'wandb' + logger: "stdout" log_dir: "metrics_log" - log_freq: 10 + log_freq: + loss: 10 # TODO: required by torchtitan # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 13c59a776..349135609 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -3,8 +3,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory 
of this source tree. +import json import os +import sys +import time from abc import ABC, abstractmethod +from pathlib import Path from typing import Mapping, Optional import torch @@ -20,10 +24,11 @@ class MetricLogger(ABC): """Abstract metric logger. Args: - log_freq (int): calls to `log` and `log_dict` will be ignored if `step % log_freq != 0` + log_freq (Mapping[str, int]): + calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0` """ - def __init__(self, log_freq: int): + def __init__(self, log_freq: Mapping[str, int]): self._log_freq = log_freq self._step = None @@ -31,10 +36,11 @@ def set_step(self, step: int) -> None: """Subsequent log calls will use this step number by default if not provided to the log call.""" self._step = step - def is_log_step(self, step: Optional[int] = None): + def is_log_step(self, name: str, step: Optional[int] = None): """Returns true if the current step is a logging step. Args: + name (str): metric name (for checking the log freq for this metric) step (int): current step. if not given, will use the one last provided via set_step() """ if step is None: @@ -42,7 +48,7 @@ def is_log_step(self, step: Optional[int] = None): self._step is not None ), "`step` arg required if `set_step` has not been called." step = self._step - return step % self._log_freq == 0 + return step % self._log_freq[name] == 0 def log( self, @@ -62,16 +68,16 @@ def log( self._step is not None ), "`step` arg required if `set_step` has not been called." step = self._step - if step % self._log_freq == 0: + if step % self._log_freq[name] == 0: self._log(name, data, step) def log_dict( - self, payload: Mapping[str, Scalar], step: Optional[int] = None + self, metrics: Mapping[str, Scalar], step: Optional[int] = None ) -> None: """Log multiple scalar values if this is a logging step. Args: - payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value step (int): step value to record. if not given, will use the one last provided via set_step() """ if step is None: @@ -79,8 +85,14 @@ def log_dict( self._step is not None ), "`step` arg required if `set_step` has not been called." step = self._step - if step % self._log_freq == 0: - self._log_dict(payload, step) + + log_step_metrics = { + name: value + for name, value in metrics.items() + if step % self._log_freq[name] == 0 + } + if log_step_metrics: + self._log_dict(log_step_metrics, step) @abstractmethod def _log(self, name: str, data: Scalar, step: int) -> None: @@ -103,14 +115,188 @@ def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: """ pass + def __del__(self) -> None: + self.close() + def close(self) -> None: """ Close log resource, flushing if necessary. + This will automatically be called via __del__ when the instance goes out of scope. Logs should not be written after `close` is called. """ pass +class StdoutLogger(MetricLogger): + """Logger to standard output.""" + + def _log(self, name: str, data: Scalar, step: int) -> None: + print(f"Step {step} | {name}:{data}") + + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + print(f"Step {step} | ", end="") + for name, data in payload.items(): + print(f"{name}:{data} ", end="") + print("\n", end="") + + def close(self) -> None: + sys.stdout.flush() + + +class DiskLogger(MetricLogger): + """Logger to disk. + + Args: + log_dir (str): directory to store logs + output_fmt (str): format of the output file. Default: 'txt'. 
+ Supported formats: 'txt', 'jsonl'. + filename (Optional[str]): optional filename to write logs to. + Default: None, in which case log_{unixtimestamp}.txt will be used. + **kwargs: additional arguments + + Warning: + This logger is not thread-safe. + + Note: + This logger creates a new file based on the current time. + """ + + def __init__( + self, + log_freq: Mapping[str, int], + log_dir: str, + output_fmt: str = "txt", + filename: Optional[str] = None, + **kwargs, + ): + super().__init__(log_freq) + + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + self.output_fmt = output_fmt + assert self.output_fmt in [ + "txt", + "jsonl", + ], f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." + if not filename: + unix_timestamp = int(time.time()) + filename = f"log_{unix_timestamp}.{self.output_fmt}" + self._file_name = self.log_dir / filename + self._file = open(self._file_name, "a") + print(f"Writing logs to {self._file_name}") + + def path_to_log_file(self) -> Path: + return self._file_name + + def _log(self, name: str, data: Scalar, step: int) -> None: + if self.output_fmt == "txt": + self._file.write(f"Step {step} | {name}:{data}\n") + elif self.output_fmt == "jsonl": + json.dump( + {"step": step, name: data}, + self._file, + default=lambda x: x.tolist() if isinstance(x, torch.Tensor) else str(x), + ) + self._file.write("\n") + else: + raise ValueError( + f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." + ) + self._file.flush() + + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + if self.output_fmt == "txt": + self._file.write(f"Step {step} | ") + for name, data in payload.items(): + self._file.write(f"{name}:{data} ") + elif self.output_fmt == "jsonl": + json.dump( + {"step": step} | {name: data for name, data in payload.items()}, + self._file, + default=lambda x: x.tolist() if isinstance(x, torch.Tensor) else str(x), + ) + else: + raise ValueError( + f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." + ) + self._file.write("\n") + self._file.flush() + + def close(self) -> None: + self._file.close() + + +class TensorBoardLogger(MetricLogger): + """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). + + Args: + log_dir (str): torch.TensorBoard log directory + organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current + run. Having sub-directories allows you to compare logs across runs. When TensorBoard is + passed a logdir at startup, it recursively walks the directory tree rooted at logdir looking for + subdirectories that contain tfevents data. Every time it encounters such a subdirectory, + it loads it as a new run, and the frontend will organize the data accordingly. + Recommended value is `True`. Run `tensorboard --logdir my_log_dir` to view the logs. + **kwargs: additional arguments + + Example: + >>> from torchtune.training.metric_logging import TensorBoardLogger + >>> logger = TensorBoardLogger(log_dir="my_log_dir") + >>> logger.log("my_metric", 1.0, 1) + >>> logger.log_dict({"my_metric": 1.0}, 1) + >>> logger.close() + + Note: + This utility requires the tensorboard package to be installed. + You can install it with `pip install tensorboard`. + In order to view TensorBoard logs, you need to run `tensorboard --logdir my_log_dir` in your terminal. 
+ """ + + def __init__( + self, + log_freq: Mapping[str, int], + log_dir: str, + organize_logs: bool = True, + **kwargs, + ): + super().__init__(log_freq) + + from torch.utils.tensorboard import SummaryWriter + + self._writer: Optional[SummaryWriter] = None + _, self._rank = get_world_size_and_rank() + + # In case organize_logs is `True`, update log_dir to include a subdirectory for the + # current run + self.log_dir = ( + os.path.join(log_dir, f"run_{self._rank}_{time.time()}") + if organize_logs + else log_dir + ) + + # Initialize the log writer only if we're on rank 0. + if self._rank == 0: + self._writer = SummaryWriter(log_dir=self.log_dir) + + def log(self, name: str, data: Scalar, step: int) -> None: + if self._writer: + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + for name, data in payload.items(): + self.log(name, data, step) + + def __del__(self) -> None: + if self._writer: + self._writer.close() + self._writer = None + + def close(self) -> None: + if self._writer: + self._writer.close() + self._writer = None + + class WandBLogger(MetricLogger): """Logger for use w/ Weights and Biases application (https://wandb.ai/). For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init. @@ -143,12 +329,15 @@ class WandBLogger(MetricLogger): def __init__( self, + log_freq: Mapping[str, int], log_dir: str, - project: str = "torchforge", + project: str, entity: Optional[str] = None, group: Optional[str] = None, **kwargs, ): + super().__init__(log_freq) + try: import wandb except ImportError as e: @@ -160,11 +349,7 @@ def __init__( if not os.path.exists(log_dir): os.makedirs(log_dir) - rank = ( - torch.distributed.get_rank() - if torch.distributed.is_available() and torch.distributed.is_initialized() - else 0 - ) + rank = _get_rank() if self._wandb.run is None and rank == 0: # we check if wandb.init got called externally run = self._wandb.init( @@ -193,16 +378,23 @@ def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: if self._wandb.run: self._wandb.log({**payload, "step": step}) - def __del__(self) -> None: - # extra check for when there is an import error - if hasattr(self, "_wandb") and self._wandb.run: - self._wandb.finish() - def close(self) -> None: - if self._wandb.run: + if hasattr(self, "_wandb") and self._wandb.run: self._wandb.finish() +# TODO: replace with direct instantiation via a path to the class in the config METRIC_LOGGER_STR_TO_CLS = { + "stdout": StdoutLogger, + "disk": DiskLogger, + "tensorboard": TensorBoardLogger, "wandb": WandBLogger, } + + +def _get_rank(): + return ( + torch.distributed.get_rank() + if torch.distributed.is_available() and torch.distributed.is_initialized() + else 0 + ) From de31799db0915b3aea8a0bf5a7967d4c28ce337e Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 14:31:39 -0700 Subject: [PATCH 03/10] move base class to interfaces file --- src/forge/interfaces.py | 111 ++++++++++++++++++++++++++++- src/forge/util/metric_logging.py | 117 ++----------------------------- 2 files changed, 114 insertions(+), 114 deletions(-) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index f19f379cb..a0bedbb66 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -5,11 +5,11 @@ # LICENSE file in the root directory of this source tree. 
from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Mapping, Optional from monarch.actor import Actor, endpoint -from forge.types import Action, Message, Observation, State +from forge.types import Action, Message, Observation, Scalar, State class Transform(ABC): @@ -150,3 +150,110 @@ def tokenize_messages( tuple[list[int], list[bool]]: The list of token ids and the list of masks. """ pass + + +class MetricLogger(ABC): + """Abstract metric logger. + + Args: + log_freq (Mapping[str, int]): + calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0` + """ + + def __init__(self, log_freq: Mapping[str, int]): + self._log_freq = log_freq + self._step = None + + def set_step(self, step: int) -> None: + """Subsequent log calls will use this step number by default if not provided to the log call.""" + self._step = step + + def is_log_step(self, name: str, step: Optional[int] = None): + """Returns true if the current step is a logging step. + + Args: + name (str): metric name (for checking the log freq for this metric) + step (int): current step. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + return step % self._log_freq[name] == 0 + + def log( + self, + name: str, + data: Scalar, + step: Optional[int] = None, + ) -> None: + """Log scalar data if this is a logging step. + + Args: + name (str): tag name used to group scalars + data (Scalar): scalar data to log + step (int): step value to record. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + if step % self._log_freq[name] == 0: + self._log(name, data, step) + + def log_dict( + self, metrics: Mapping[str, Scalar], step: Optional[int] = None + ) -> None: + """Log multiple scalar values if this is a logging step. + + Args: + metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value + step (int): step value to record. if not given, will use the one last provided via set_step() + """ + if step is None: + assert ( + self._step is not None + ), "`step` arg required if `set_step` has not been called." + step = self._step + + log_step_metrics = { + name: value + for name, value in metrics.items() + if step % self._log_freq[name] == 0 + } + if log_step_metrics: + self._log_dict(log_step_metrics, step) + + @abstractmethod + def _log(self, name: str, data: Scalar, step: int) -> None: + """Log scalar data. + + Args: + name (str): tag name used to group scalars + data (Scalar): scalar data to log + step (int): step value to record + """ + pass + + @abstractmethod + def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + """Log multiple scalar values. + + Args: + payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + step (int): step value to record + """ + pass + + def __del__(self) -> None: + self.close() + + def close(self) -> None: + """ + Close log resource, flushing if necessary. + This will automatically be called via __del__ when the instance goes out of scope. + Logs should not be written after `close` is called. 
+ """ + pass diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 349135609..f24fb14d8 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -7,12 +7,12 @@ import os import sys import time -from abc import ABC, abstractmethod from pathlib import Path from typing import Mapping, Optional import torch +from forge.interfaces import MetricLogger from forge.types import Scalar @@ -20,113 +20,6 @@ def get_metric_logger(logger: str = "stdout", **log_config): return METRIC_LOGGER_STR_TO_CLS[logger](**log_config) -class MetricLogger(ABC): - """Abstract metric logger. - - Args: - log_freq (Mapping[str, int]): - calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0` - """ - - def __init__(self, log_freq: Mapping[str, int]): - self._log_freq = log_freq - self._step = None - - def set_step(self, step: int) -> None: - """Subsequent log calls will use this step number by default if not provided to the log call.""" - self._step = step - - def is_log_step(self, name: str, step: Optional[int] = None): - """Returns true if the current step is a logging step. - - Args: - name (str): metric name (for checking the log freq for this metric) - step (int): current step. if not given, will use the one last provided via set_step() - """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - return step % self._log_freq[name] == 0 - - def log( - self, - name: str, - data: Scalar, - step: Optional[int] = None, - ) -> None: - """Log scalar data if this is a logging step. - - Args: - name (str): tag name used to group scalars - data (Scalar): scalar data to log - step (int): step value to record. if not given, will use the one last provided via set_step() - """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - if step % self._log_freq[name] == 0: - self._log(name, data, step) - - def log_dict( - self, metrics: Mapping[str, Scalar], step: Optional[int] = None - ) -> None: - """Log multiple scalar values if this is a logging step. - - Args: - metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value - step (int): step value to record. if not given, will use the one last provided via set_step() - """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - - log_step_metrics = { - name: value - for name, value in metrics.items() - if step % self._log_freq[name] == 0 - } - if log_step_metrics: - self._log_dict(log_step_metrics, step) - - @abstractmethod - def _log(self, name: str, data: Scalar, step: int) -> None: - """Log scalar data. - - Args: - name (str): tag name used to group scalars - data (Scalar): scalar data to log - step (int): step value to record - """ - pass - - @abstractmethod - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: - """Log multiple scalar values. - - Args: - payload (Mapping[str, Scalar]): dictionary of tag name and scalar value - step (int): step value to record - """ - pass - - def __del__(self) -> None: - self.close() - - def close(self) -> None: - """ - Close log resource, flushing if necessary. - This will automatically be called via __del__ when the instance goes out of scope. - Logs should not be written after `close` is called. 
- """ - pass - - class StdoutLogger(MetricLogger): """Logger to standard output.""" @@ -264,21 +157,21 @@ def __init__( from torch.utils.tensorboard import SummaryWriter self._writer: Optional[SummaryWriter] = None - _, self._rank = get_world_size_and_rank() + rank = _get_rank() # In case organize_logs is `True`, update log_dir to include a subdirectory for the # current run self.log_dir = ( - os.path.join(log_dir, f"run_{self._rank}_{time.time()}") + os.path.join(log_dir, f"run_{rank}_{time.time()}") if organize_logs else log_dir ) # Initialize the log writer only if we're on rank 0. - if self._rank == 0: + if rank == 0: self._writer = SummaryWriter(log_dir=self.log_dir) - def log(self, name: str, data: Scalar, step: int) -> None: + def _log(self, name: str, data: Scalar, step: int) -> None: if self._writer: self._writer.add_scalar(name, data, global_step=step, new_style=True) From 120a5102cc277230766d270d72711056786b08c1 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 14:40:17 -0700 Subject: [PATCH 04/10] fix config --- apps/sft/llama3_8b.yaml | 1 - apps/sft_v2/llama3_8b.yaml | 1 - src/forge/util/metric_logging.py | 85 -------------------------------- 3 files changed, 87 deletions(-) diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 967885651..1ba8b3d34 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -4,7 +4,6 @@ metrics: logger: "stdout" - log_dir: "metrics_log" log_freq: loss: 10 diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index ad5892b3b..477080fe6 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -4,7 +4,6 @@ metrics: logger: "stdout" - log_dir: "metrics_log" log_freq: loss: 10 diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index f24fb14d8..93ce13a0d 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -3,11 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json import os import sys import time -from pathlib import Path from typing import Mapping, Optional import torch @@ -36,89 +34,6 @@ def close(self) -> None: sys.stdout.flush() -class DiskLogger(MetricLogger): - """Logger to disk. - - Args: - log_dir (str): directory to store logs - output_fmt (str): format of the output file. Default: 'txt'. - Supported formats: 'txt', 'jsonl'. - filename (Optional[str]): optional filename to write logs to. - Default: None, in which case log_{unixtimestamp}.txt will be used. - **kwargs: additional arguments - - Warning: - This logger is not thread-safe. - - Note: - This logger creates a new file based on the current time. - """ - - def __init__( - self, - log_freq: Mapping[str, int], - log_dir: str, - output_fmt: str = "txt", - filename: Optional[str] = None, - **kwargs, - ): - super().__init__(log_freq) - - self.log_dir = Path(log_dir) - self.log_dir.mkdir(parents=True, exist_ok=True) - self.output_fmt = output_fmt - assert self.output_fmt in [ - "txt", - "jsonl", - ], f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." 
- if not filename: - unix_timestamp = int(time.time()) - filename = f"log_{unix_timestamp}.{self.output_fmt}" - self._file_name = self.log_dir / filename - self._file = open(self._file_name, "a") - print(f"Writing logs to {self._file_name}") - - def path_to_log_file(self) -> Path: - return self._file_name - - def _log(self, name: str, data: Scalar, step: int) -> None: - if self.output_fmt == "txt": - self._file.write(f"Step {step} | {name}:{data}\n") - elif self.output_fmt == "jsonl": - json.dump( - {"step": step, name: data}, - self._file, - default=lambda x: x.tolist() if isinstance(x, torch.Tensor) else str(x), - ) - self._file.write("\n") - else: - raise ValueError( - f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." - ) - self._file.flush() - - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: - if self.output_fmt == "txt": - self._file.write(f"Step {step} | ") - for name, data in payload.items(): - self._file.write(f"{name}:{data} ") - elif self.output_fmt == "jsonl": - json.dump( - {"step": step} | {name: data for name, data in payload.items()}, - self._file, - default=lambda x: x.tolist() if isinstance(x, torch.Tensor) else str(x), - ) - else: - raise ValueError( - f"Unsupported output format: {self.output_fmt}. Supported formats: 'txt', 'jsonl'." - ) - self._file.write("\n") - self._file.flush() - - def close(self) -> None: - self._file.close() - - class TensorBoardLogger(MetricLogger): """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). From a964fcbd1bd52a850f4608f0666a8af324d73cc6 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 15:06:22 -0700 Subject: [PATCH 05/10] fix --- src/forge/util/metric_logging.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 93ce13a0d..de7e11344 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -194,7 +194,6 @@ def close(self) -> None: # TODO: replace with direct instantiation via a path to the class in the config METRIC_LOGGER_STR_TO_CLS = { "stdout": StdoutLogger, - "disk": DiskLogger, "tensorboard": TensorBoardLogger, "wandb": WandBLogger, } From f116007663a946b386242c10a7dea488e06b78b3 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 15:57:16 -0700 Subject: [PATCH 06/10] fix --- src/forge/util/metric_logging.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index de7e11344..415fda920 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -63,7 +63,7 @@ class TensorBoardLogger(MetricLogger): def __init__( self, log_freq: Mapping[str, int], - log_dir: str, + log_dir: str = "metrics_log", organize_logs: bool = True, **kwargs, ): @@ -138,8 +138,8 @@ class WandBLogger(MetricLogger): def __init__( self, log_freq: Mapping[str, int], - log_dir: str, project: str, + log_dir: str = "metrics_log", entity: Optional[str] = None, group: Optional[str] = None, **kwargs, @@ -169,14 +169,11 @@ def __init__( ) if self._wandb.run: - self._wandb.run._label(repo="torchtune") + # define default x-axis (for latest wandb versions) + if getattr(self._wandb, "define_metric", None): + self._wandb.define_metric("step") + self._wandb.define_metric("*", step_metric="step", step_sync=True) - # define default x-axis (for latest wandb versions) - if getattr(self._wandb, 
"define_metric", None): - self._wandb.define_metric("step") - self._wandb.define_metric("*", step_metric="step", step_sync=True) - - self.config_allow_val_change = kwargs.get("allow_val_change", False) def _log(self, name: str, data: Scalar, step: int) -> None: if self._wandb.run: From 096bd20c02d2f2026d9e4e2fce8391ef2c6ba1c0 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Aug 2025 16:00:37 -0700 Subject: [PATCH 07/10] fmt --- src/forge/util/metric_logging.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 415fda920..4db2f5f37 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -174,7 +174,6 @@ def __init__( self._wandb.define_metric("step") self._wandb.define_metric("*", step_metric="step", step_sync=True) - def _log(self, name: str, data: Scalar, step: int) -> None: if self._wandb.run: self._wandb.log({name: data, "step": step}) From e6ed572d3cc5430c64c031da6f52ab9e6156e253 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Aug 2025 12:00:19 -0700 Subject: [PATCH 08/10] addressing comments --- apps/sft/llama3_8b.yaml | 5 +- apps/sft_v2/llama3_8b.yaml | 4 +- src/forge/interfaces.py | 85 ++++------------------- src/forge/types.py | 5 +- src/forge/util/__init__.py | 5 +- src/forge/util/distributed.py | 20 ++++++ src/forge/util/logging.py | 82 +---------------------- src/forge/util/metric_logging.py | 111 +++++++++++++++++++++---------- 8 files changed, 120 insertions(+), 197 deletions(-) create mode 100644 src/forge/util/distributed.py diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index 1ba8b3d34..afd182e66 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -3,8 +3,9 @@ # enable_profiling: false metrics: - logger: "stdout" - log_freq: + logger: wandb + project: dsafjkdsafjlskdjf + freq: loss: 10 # TODO: required by torchtitan diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index 477080fe6..1e21373fc 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -3,8 +3,8 @@ # enable_profiling: false metrics: - logger: "stdout" - log_freq: + logger: stdout + freq: loss: 10 # TODO: required by torchtitan diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index a0bedbb66..f11bffd05 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional +from typing import Any, Mapping from monarch.actor import Actor, endpoint @@ -153,82 +153,21 @@ def tokenize_messages( class MetricLogger(ABC): - """Abstract metric logger. + """Abstract metric logger.""" - Args: - log_freq (Mapping[str, int]): - calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0` - """ - - def __init__(self, log_freq: Mapping[str, int]): - self._log_freq = log_freq - self._step = None - - def set_step(self, step: int) -> None: - """Subsequent log calls will use this step number by default if not provided to the log call.""" - self._step = step - - def is_log_step(self, name: str, step: Optional[int] = None): + @abstractmethod + def is_log_step(self, name: str, step: int) -> bool: """Returns true if the current step is a logging step. Args: - name (str): metric name (for checking the log freq for this metric) - step (int): current step. 
if not given, will use the one last provided via set_step() - """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - return step % self._log_freq[name] == 0 - - def log( - self, - name: str, - data: Scalar, - step: Optional[int] = None, - ) -> None: - """Log scalar data if this is a logging step. - - Args: - name (str): tag name used to group scalars - data (Scalar): scalar data to log - step (int): step value to record. if not given, will use the one last provided via set_step() + name (str): metric name (for checking the freq for this metric) + step (int): current step """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - if step % self._log_freq[name] == 0: - self._log(name, data, step) - - def log_dict( - self, metrics: Mapping[str, Scalar], step: Optional[int] = None - ) -> None: - """Log multiple scalar values if this is a logging step. - - Args: - metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value - step (int): step value to record. if not given, will use the one last provided via set_step() - """ - if step is None: - assert ( - self._step is not None - ), "`step` arg required if `set_step` has not been called." - step = self._step - - log_step_metrics = { - name: value - for name, value in metrics.items() - if step % self._log_freq[name] == 0 - } - if log_step_metrics: - self._log_dict(log_step_metrics, step) + pass @abstractmethod - def _log(self, name: str, data: Scalar, step: int) -> None: - """Log scalar data. + def log(self, name: str, data: Scalar, step: int) -> None: + """Log scalar data if this is a logging step. Args: name (str): tag name used to group scalars @@ -238,11 +177,11 @@ def _log(self, name: str, data: Scalar, step: int) -> None: pass @abstractmethod - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: - """Log multiple scalar values. + def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None: + """Log multiple scalar values if this is a logging step. Args: - payload (Mapping[str, Scalar]): dictionary of tag name and scalar value + metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value step (int): step value to record """ pass diff --git a/src/forge/types.py b/src/forge/types.py index ac0938d9b..ce79cdbe3 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -7,9 +7,6 @@ from dataclasses import dataclass, field from typing import Any, Literal, TypedDict, Union -from numpy import ndarray -from torch import Tensor - class Message(TypedDict): role: str @@ -103,4 +100,4 @@ class ProcessConfig: image: str = "forge_workspace:latest" -Scalar = Union[Tensor, ndarray, int, float] +Scalar = Union[int, float] diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py index c9cf0bfd3..5fb03b0f9 100644 --- a/src/forge/util/__init__.py +++ b/src/forge/util/__init__.py @@ -3,11 +3,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from .logging import deprecated, get_logger, log_once, log_rank_zero +from .distributed import get_world_size_and_rank +from .logging import get_logger, log_once, log_rank_zero from .metric_logging import get_metric_logger __all__ = [ - "deprecated", + "get_world_size_and_rank", "get_logger", "log_once", "log_rank_zero", diff --git a/src/forge/util/distributed.py b/src/forge/util/distributed.py new file mode 100644 index 000000000..b32be7291 --- /dev/null +++ b/src/forge/util/distributed.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def get_world_size_and_rank() -> tuple[int, int]: + """Function that gets the current world size (aka total number + of ranks) and rank number of the current process in the default process group. + + Returns: + tuple[int, int]: world size, rank + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_world_size(), torch.distributed.get_rank() + else: + return 1, 0 diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py index fcfb047a0..e53218ccd 100644 --- a/src/forge/util/logging.py +++ b/src/forge/util/logging.py @@ -4,11 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import inspect import logging -import warnings -from functools import lru_cache, wraps -from typing import Callable, Optional, TypeVar +from functools import lru_cache +from typing import Optional, TypeVar from torch import distributed as dist @@ -69,79 +67,3 @@ def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> Non Defaults to ``logging.INFO``. """ log_rank_zero(logger=logger, msg=msg, level=level) - - -def deprecated(msg: str = "") -> Callable[[T], T]: - """ - Decorator to mark an object as deprecated and print additional message. - - Args: - msg (str): additional information to print after warning. - - Returns: - Callable[[T], T]: the decorated object. - """ - - @lru_cache(maxsize=1) - def warn(obj): - rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 - if rank != 0: - return - warnings.warn( - f"{obj.__name__} is deprecated and will be removed in future versions. " - + msg, - category=FutureWarning, - stacklevel=3, - ) - - def decorator(obj): - @wraps(obj) - def wrapper(*args, **kwargs): - warn(obj) - return obj(*args, **kwargs) - - return wrapper - - return decorator - - -def deprecate_parameter(param_name: str, msg: str = "") -> Callable[[T], T]: - """ - Decorator to mark a parameter as deprecated and print additional message. - - Args: - param_name (str): The name of the parameter. - msg (str): additional information to print after warning. - - Returns: - Callable[[T], T]: the decorated object. - """ - - @lru_cache(maxsize=1) - def warn(obj): - rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 - if rank != 0: - return - warnings.warn( - f"{param_name} is deprecated for {obj.__name__} and will be removed in future versions. 
" - + msg, - category=FutureWarning, - stacklevel=3, - ) - - def decorator(obj): - sig = inspect.signature(obj) - - @wraps(obj) - def wrapper(*args, **kwargs): - # Check positional and kwargs - bound_args = sig.bind_partial(*args, **kwargs) - all_args = {**bound_args.arguments} - all_args.update(kwargs) - if param_name in all_args: - warn(obj) - return obj(*args, **kwargs) - - return wrapper - - return decorator diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 4db2f5f37..8cfb552f3 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -8,10 +8,9 @@ import time from typing import Mapping, Optional -import torch - from forge.interfaces import MetricLogger from forge.types import Scalar +from forge.util.distributed import get_world_size_and_rank def get_metric_logger(logger: str = "stdout", **log_config): @@ -19,14 +18,41 @@ def get_metric_logger(logger: str = "stdout", **log_config): class StdoutLogger(MetricLogger): - """Logger to standard output.""" + """Logger to standard output. + + Args: + freq (Mapping[str, int]): + calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0` + """ + + def __init__(self, freq: Mapping[str, int]): + self._freq = freq + + def is_log_step(self, name: str, step: int) -> bool: + """Returns true if the current step is a logging step. - def _log(self, name: str, data: Scalar, step: int) -> None: + Args: + name (str): metric name (for checking the freq for this metric) + step (int): current step + """ + return step % self._freq[name] == 0 + + def log(self, name: str, data: Scalar, step: int) -> None: + if not self.is_log_step(name, step): + return print(f"Step {step} | {name}:{data}") - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None: + log_step_metrics = { + name: value + for name, value in metrics.items() + if self.is_log_step(name, step) + } + if not log_step_metrics: + return + print(f"Step {step} | ", end="") - for name, data in payload.items(): + for name, data in log_step_metrics.items(): print(f"{name}:{data} ", end="") print("\n", end="") @@ -38,6 +64,8 @@ class TensorBoardLogger(MetricLogger): """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html). Args: + freq (Mapping[str, int]): + calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0` log_dir (str): torch.TensorBoard log directory organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current run. Having sub-directories allows you to compare logs across runs. 
When TensorBoard is @@ -62,17 +90,16 @@ class TensorBoardLogger(MetricLogger): def __init__( self, - log_freq: Mapping[str, int], + freq: Mapping[str, int], log_dir: str = "metrics_log", organize_logs: bool = True, **kwargs, ): - super().__init__(log_freq) - from torch.utils.tensorboard import SummaryWriter + self._freq = freq self._writer: Optional[SummaryWriter] = None - rank = _get_rank() + _, rank = get_world_size_and_rank() # In case organize_logs is `True`, update log_dir to include a subdirectory for the # current run @@ -86,18 +113,23 @@ def __init__( if rank == 0: self._writer = SummaryWriter(log_dir=self.log_dir) - def _log(self, name: str, data: Scalar, step: int) -> None: - if self._writer: - self._writer.add_scalar(name, data, global_step=step, new_style=True) + def is_log_step(self, name: str, step: int) -> bool: + """Returns true if the current step is a logging step. - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: - for name, data in payload.items(): - self.log(name, data, step) + Args: + name (str): metric name (for checking the freq for this metric) + step (int): current step + """ + return step % self._freq[name] == 0 - def __del__(self) -> None: + def log(self, name: str, data: Scalar, step: int) -> None: if self._writer: - self._writer.close() - self._writer = None + self._writer.add_scalar(name, data, global_step=step, new_style=True) + + def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None: + for name, data in metrics.items(): + if self.is_log_step(name, step): + self.log(name, data, step) def close(self) -> None: if self._writer: @@ -110,6 +142,8 @@ class WandBLogger(MetricLogger): For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init. Args: + freq (Mapping[str, int]): + calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0` log_dir (Optional[str]): WandB log directory. project (str): WandB project name. Default is `torchtune`. entity (Optional[str]): WandB entity name. If you don't specify an entity, @@ -137,14 +171,14 @@ class WandBLogger(MetricLogger): def __init__( self, - log_freq: Mapping[str, int], + freq: Mapping[str, int], project: str, log_dir: str = "metrics_log", entity: Optional[str] = None, group: Optional[str] = None, **kwargs, ): - super().__init__(log_freq) + self._freq = freq try: import wandb @@ -157,7 +191,7 @@ def __init__( if not os.path.exists(log_dir): os.makedirs(log_dir) - rank = _get_rank() + _, rank = get_world_size_and_rank() if self._wandb.run is None and rank == 0: # we check if wandb.init got called externally run = self._wandb.init( @@ -174,13 +208,30 @@ def __init__( self._wandb.define_metric("step") self._wandb.define_metric("*", step_metric="step", step_sync=True) - def _log(self, name: str, data: Scalar, step: int) -> None: - if self._wandb.run: + def is_log_step(self, name: str, step: int) -> bool: + """Returns true if the current step is a logging step. 
+ + Args: + name (str): metric name (for checking the freq for this metric) + step (int): current step + """ + return step % self._freq[name] == 0 + + def log(self, name: str, data: Scalar, step: int) -> None: + if self._wandb.run and self.is_log_step(name, step): self._wandb.log({name: data, "step": step}) - def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: + def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None: + log_step_metrics = { + name: value + for name, value in metrics.items() + if self.is_log_step(name, step) + } + if not log_step_metrics: + return + if self._wandb.run: - self._wandb.log({**payload, "step": step}) + self._wandb.log({**metrics, "step": step}) def close(self) -> None: if hasattr(self, "_wandb") and self._wandb.run: @@ -193,11 +244,3 @@ def close(self) -> None: "tensorboard": TensorBoardLogger, "wandb": WandBLogger, } - - -def _get_rank(): - return ( - torch.distributed.get_rank() - if torch.distributed.is_available() and torch.distributed.is_initialized() - else 0 - ) From b1dd753eb1504b2d54cb66c0728c7b3ee1d405c2 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Aug 2025 13:33:51 -0700 Subject: [PATCH 09/10] reverting sft_v2 changes --- apps/sft/main.py | 2 +- apps/sft_v2/llama3_8b.yaml | 8 ++++---- apps/sft_v2/main.py | 5 +---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index 988095931..cabf43abd 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -186,7 +186,7 @@ def train_step(self, batch) -> None: loss = self.forward_backward(batch, labels) self.pbar.update(1) self.pbar.set_description(f"{self.current_step}|Loss: {loss}") - self.metric_logger.log("loss", loss, self.current_step) + self.metric_logger.log("loss", loss.item(), self.current_step) self.optimizers.step() self.lr_schedulers.step() diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index 1e21373fc..1bb180faa 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -2,10 +2,10 @@ # profiling: # enable_profiling: false -metrics: - logger: stdout - freq: - loss: 10 +# metrics: +# log_freq: 10 +# enable_tensorboard: true +# save_tb_folder: "tb" # TODO: required by torchtitan # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265 diff --git a/apps/sft_v2/main.py b/apps/sft_v2/main.py index 0426d0be3..1c1f43c01 100644 --- a/apps/sft_v2/main.py +++ b/apps/sft_v2/main.py @@ -29,7 +29,6 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer -from forge.util import get_metric_logger from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -75,7 +74,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine): def __init__(self, job_config: ForgeJobConfig): self.current_step = 0 self.num_training_steps = job_config.training.steps - self.metric_logger = get_metric_logger(**job_config.metrics) + self.metric_logger = None # TODO: fix this self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) @@ -239,8 +238,6 @@ def train_step(self, batch) -> None: logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) - 
-        self.metric_logger.log("loss", loss, self.current_step)
-
         self.optimizers.step()
         self.lr_schedulers.step()
 

From 903204aa42aa9e88176a22eae9a66d7dd3bfc573 Mon Sep 17 00:00:00 2001
From: Calvin Pelletier
Date: Mon, 25 Aug 2025 18:11:39 -0700
Subject: [PATCH 10/10] fixing docstrings

---
 apps/sft/llama3_8b.yaml          |  3 +-
 src/forge/util/metric_logging.py | 47 +++++++++++++++++++++++++++++---
 2 files changed, 44 insertions(+), 6 deletions(-)

diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
index afd182e66..72c2fe1ec 100644
--- a/apps/sft/llama3_8b.yaml
+++ b/apps/sft/llama3_8b.yaml
@@ -3,8 +3,7 @@
 #   enable_profiling: false
 
 metrics:
-  logger: wandb
-  project: dsafjkdsafjlskdjf
+  logger: tensorboard
   freq:
     loss: 10
 
diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py
index 8cfb552f3..75790c813 100644
--- a/src/forge/util/metric_logging.py
+++ b/src/forge/util/metric_logging.py
@@ -38,11 +38,24 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0
 
     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if not self.is_log_step(name, step):
             return
         print(f"Step {step} | {name}:{data}")
 
     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         log_step_metrics = {
             name: value
             for name, value in metrics.items()
@@ -76,8 +89,8 @@ class TensorBoardLogger(MetricLogger):
         **kwargs: additional arguments
 
     Example:
-        >>> from torchtune.training.metric_logging import TensorBoardLogger
-        >>> logger = TensorBoardLogger(log_dir="my_log_dir")
+        >>> from forge.util.metric_logging import TensorBoardLogger
+        >>> logger = TensorBoardLogger(freq={"my_metric": 1}, log_dir="my_log_dir")
         >>> logger.log("my_metric", 1.0, 1)
         >>> logger.log_dict({"my_metric": 1.0}, 1)
         >>> logger.close()
@@ -123,10 +136,23 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0
 
     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)
 
     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         for name, data in metrics.items():
             if self.is_log_step(name, step):
                 self.log(name, data, step)
@@ -153,8 +179,8 @@ class WandBLogger(MetricLogger):
         **kwargs: additional arguments to pass to wandb.init
 
     Example:
-        >>> from torchtune.training.metric_logging import WandBLogger
-        >>> logger = WandBLogger(log_dir="wandb", project="my_project", entity="my_entity", group="my_group")
+        >>> from forge.util.metric_logging import WandBLogger
+        >>> logger = WandBLogger(freq={"my_metric": 1}, log_dir="wandb", project="my_project")
         >>> logger.log("my_metric", 1.0, 1)
         >>> logger.log_dict({"my_metric": 1.0}, 1)
         >>> logger.close()
@@ -218,10 +244,23 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0
 
     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if self._wandb.run and self.is_log_step(name, step):
             self._wandb.log({name: data, "step": step})
 
     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         log_step_metrics = {
             name: value
             for name, value in metrics.items()
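For reference, a minimal usage sketch of the metric-logging API introduced by this patch series. It is not part of the diffs above: the get_metric_logger(logger=..., freq=...) call mirrors how apps/sft/main.py expands the `metrics:` block of llama3_8b.yaml (`get_metric_logger(**job_config.metrics)`), while the frequency values, the extra "lr" metric, and the standalone loop are illustrative assumptions only.

    from forge.util import get_metric_logger

    # Python equivalent of the YAML `metrics:` block: a backend name plus a
    # per-metric logging frequency. "tensorboard" could be swapped for "stdout"
    # or "wandb" (the wandb backend additionally expects a project= argument).
    metric_logger = get_metric_logger(
        logger="tensorboard",
        log_dir="metrics_log",
        freq={"loss": 10, "lr": 100},  # assumed values; every metric logged below needs an entry
    )

    for step in range(1, 1001):
        loss = 1.0 / step  # placeholder; pass a plain float (cf. loss.item() in PATCH 09)
        # log_dict() keeps only the entries whose frequency matches this step
        # (step % freq[name] == 0); the stdout and wandb backends apply the same
        # check inside log() via is_log_step().
        metric_logger.log("loss", loss, step)
        metric_logger.log_dict({"loss": loss, "lr": 1e-4}, step)

    metric_logger.close()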