diff --git a/apps/sft/main.py b/apps/sft/main.py
index 93ba05eed..adddde0f3 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -28,9 +28,10 @@
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.observability import get_or_create_metric_logger, record_metric, Reduce
+from forge.observability.metric_actors import GlobalLoggingActor
 from forge.util.config import parse
 
-from monarch.actor import current_rank, current_size, endpoint
+from monarch.actor import current_rank, current_size, endpoint, get_or_spawn_controller
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -110,8 +111,12 @@ def _init_dist(self):
         logger.info("env: {}".format(env))
 
     async def setup_metric_logger(self):
-        """Initialization happens in the main process. Here we just retrieve it"""
-        mlogger = await get_or_create_metric_logger()
+        """Return a handle to the "global_logger" GlobalLoggingActor controller."""
+
+        # The global logger is expected to have been created in the main process already.
+        # NOTE(review): get_or_spawn_controller, going by its name, may spawn the actor
+        # when it is missing instead of only looking it up -- confirm against monarch.
+        mlogger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
         return mlogger
 
     def record_batch_metrics(self, data_metrics: list):
@@ -123,6 +128,7 @@ def record_batch_metrics(self, data_metrics: list):
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()
         # self.train_dataloader = self.setup_data(