From 6d8c50b49e78da13adcc9699212cb9f11510f5cb Mon Sep 17 00:00:00 2001
From: Hossein Kavianihamedani
Date: Mon, 10 Nov 2025 20:59:15 -0800
Subject: [PATCH 1/2] Fix: Enable metric logging in distributed actors

- Fixes AttributeError when calling get_or_create_metric_logger() in actors
- Use get_or_spawn_controller() to get reference to global logger
- Avoids this_proc() call that returns ProcMeshRef instead of ProcMesh
- Enables WandB metric flushing from distributed training actors
- Loss values now appear in WandB dashboard across all ranks
---
 apps/sft/main.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/apps/sft/main.py b/apps/sft/main.py
index 93ba05eed..45ae255e5 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -28,9 +28,10 @@
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
+from forge.observability.metric_actors import GlobalLoggingActor
 from forge.util.config import parse

-from monarch.actor import current_rank, current_size, endpoint
+from monarch.actor import current_rank, current_size, endpoint, get_or_spawn_controller
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -110,8 +111,12 @@ def _init_dist(self):
         logger.info("env: {}".format(env))

     async def setup_metric_logger(self):
-        """Initialization happens in the main process. Here we just retrieve it"""
-        mlogger = await get_or_create_metric_logger()
+        """Retrieve the already-initialized metric logger from main process"""
+
+        # The global logger was already created in main process.
+        # Use get_or_spawn_controller from monarch to get reference to it
+        # Get reference to the existing global logger (don't create new one)
+        mlogger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
         return mlogger

     def record_batch_metrics(self, data_metrics: list):
@@ -122,7 +127,9 @@ def record_batch_metrics(self, data_metrics: list):

     @endpoint
     async def setup(self):
+        # Setup training data
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()

         # self.train_dataloader = self.setup_data(

From f8751680ea39da6e1ebdb591ae0c0d51747d376e Mon Sep 17 00:00:00 2001
From: Hossein Kavianihamedani
Date: Mon, 10 Nov 2025 21:05:58 -0800
Subject: [PATCH 2/2] remove comment

---
 apps/sft/main.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/apps/sft/main.py b/apps/sft/main.py
index 45ae255e5..adddde0f3 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -127,7 +127,6 @@ def record_batch_metrics(self, data_metrics: list):

     @endpoint
     async def setup(self):
-        # Setup training data
         self.train_dataloader = self.setup_data()
         self.mlogger = await self.setup_metric_logger()
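
Note for reviewers: below is a minimal, self-contained sketch of the retrieval
pattern these patches introduce, assembled only from the imports and the call
shown in the diff above. It assumes the main process has already created the
controller registered under "global_logger", so that get_or_spawn_controller()
returns a reference to the existing GlobalLoggingActor rather than spawning a
second one inside each training actor; that reuse behavior is inferred from the
commit message, not verified here.

    from forge.observability.metric_actors import GlobalLoggingActor
    from monarch.actor import get_or_spawn_controller


    async def setup_metric_logger():
        # Resolve the controller named "global_logger"; inside a training actor
        # this is expected to return the logger the main process already set up.
        mlogger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
        return mlogger

The name "global_logger" and the GlobalLoggingActor class are taken verbatim
from the diff; any other surrounding scaffolding (how the actor is awaited, how
the returned handle is flushed to WandB) is outside the scope of these patches.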