26 changes: 0 additions & 26 deletions apps/sft/main.py
@@ -81,34 +81,8 @@ def __init__(self, config: DictConfig):
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())
        self._init_dist()
        super().__init__(job_config)

    def _init_dist(self):
        """Initializes torch distributed.

        torchrun normally handles this, but we need to do it ourselves
        in monarch for now.

        We should consider putting this into ForgeActor, but having this
        be explicit for now.

        """
        env = {
            "RANK": str(self._rank),
            "LOCAL_RANK": str(self._rank),
            "LOCAL_WORLD_SIZE": str(self._size),
            "GROUP_RANK": str(self._size),
            "GROUP_WORLD_SIZE": str(self._size),
            "ROLE_RANK": str(self._rank),
            "ROLE_WORLD_SIZE": str(self._size),
            "ROLE_NAME": "rank",
            "WORLD_SIZE": str(self._size),
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)
        logger.info("env: {}".format(env))

    async def setup_metric_logger(self):
        """Initialization happens in the main process. Here we just retrieve it"""
        mlogger = await get_or_create_metric_logger()
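For context on what the removed _init_dist was doing: it exported the same per-worker environment variables that torchrun normally sets, so that torch.distributed could later initialize from them. Below is a minimal, hypothetical sketch (not part of this PR) of how such an env-based initialization is typically consumed; the function name and the gloo backend are placeholders, and it assumes the rendezvous variables MASTER_ADDR and MASTER_PORT are provided elsewhere.

# Hypothetical sketch, not from this PR: how torch.distributed consumes
# env vars like the ones the removed _init_dist() used to set.
import os

import torch.distributed as dist


def init_process_group_from_env(backend: str = "gloo") -> None:
    # With init_method="env://", init_process_group reads MASTER_ADDR and
    # MASTER_PORT from the environment; RANK and WORLD_SIZE can also be
    # taken from the environment, mirroring what torchrun exports.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )

In the sketch, gloo is only a placeholder; a GPU trainer would presumably pass nccl instead.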