Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions apps/sft/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.observability.metric_actors import GlobalLoggingActor
from forge.util.config import parse

from monarch.actor import current_rank, current_size, endpoint
from monarch.actor import current_rank, current_size, endpoint, get_or_spawn_controller
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torchdata.stateful_dataloader import StatefulDataLoader
Expand Down Expand Up @@ -110,8 +111,12 @@ def _init_dist(self):
logger.info("env: {}".format(env))

async def setup_metric_logger(self):
"""Initialization happens in the main process. Here we just retrieve it"""
mlogger = await get_or_create_metric_logger()
"""Retrieve the already-initialized metric logger from main process"""

# The global logger was already created in the main process; use monarch's
# get_or_spawn_controller to obtain a reference to it rather than creating
# a new one.
mlogger = await get_or_spawn_controller("global_logger", GlobalLoggingActor)
return mlogger

def record_batch_metrics(self, data_metrics: list):
Expand All @@ -123,6 +128,7 @@ def record_batch_metrics(self, data_metrics: list):
@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()

self.mlogger = await self.setup_metric_logger()

# self.train_dataloader = self.setup_data(
Expand Down
Loading