From 6e0dcdb5e919434e8cdf9fa3c96bb9519124c272 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Fri, 29 Aug 2025 10:40:16 -0700
Subject: [PATCH 1/4] skeleton code of ts integration

---
 src/forge/actors/trainer.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..f2997dd96 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import asyncio
 import logging
 import math
 import os
@@ -12,7 +13,14 @@
 from dataclasses import dataclass, field, fields
 
 import torch
+import torchtitan.experiments.forge.train_spec as forge_train_spec
+
+# from tqdm import tqdm
+
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
+from torchstore._state_dict_utils import push_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -30,8 +38,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -185,6 +191,18 @@ def train_step(self, batch) -> None:
         self.engine.lr_schedulers.step()
 
         self.current_step += 1
+
+        # save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
+        # TODOs:
+        # 1. Figure out if there is a value in calling state_dict_adatpr.to_hf()
+        # 2. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
+        #    May need to replicate the same in this code path.
+        # 3. Integrate zero-overhead version of push_state_dict.
+        # 4. Figure out a way to notify the generator app that weights are ready. This beyond the initial integration success.
+        # 5. Unify CheckpointManager and TorchStore weights save control path.
+        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
+        # if self.current_step % self.train_config.val_every_n_steps == 0:
+        #     self.validate()
         self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,

From 3e019d908b25bf362a55ddcf42b6730668b3fc34 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Sun, 31 Aug 2025 15:54:27 -0700
Subject: [PATCH 2/4] updating to match the code

---
 apps/rl/llama3_8b.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml
index a6b278d0a..06ba15882 100644
--- a/apps/rl/llama3_8b.yaml
+++ b/apps/rl/llama3_8b.yaml
@@ -18,7 +18,7 @@ trainer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
-    num_gpus: 4
+    with_gpus: 4
     num_procs: 4
 
   optimizer:
@@ -68,7 +68,7 @@ replay_buffer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
-    num_gpus: 0
+    with_gpus: 0
     num_procs: 1
 
 # policy:

From 4d8a1547d9f4d712fc0d799e829ff094db6558c9 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Sun, 31 Aug 2025 16:13:23 -0700
Subject: [PATCH 3/4] removing some merge conflit edits

---
 src/forge/actors/trainer.py | 22 ++--------------------
 1 file changed, 2 insertions(+), 20 deletions(-)

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index f2997dd96..4232ca5ca 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import asyncio
 import logging
 import math
 import os
@@ -13,14 +12,7 @@
 from dataclasses import dataclass, field, fields
 
 import torch
-import torchtitan.experiments.forge.train_spec as forge_train_spec
-
-# from tqdm import tqdm
-
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
-from torchstore._state_dict_utils import push_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -38,6 +30,8 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
+from forge.controller import ForgeActor
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -191,18 +185,6 @@ def train_step(self, batch) -> None:
         self.engine.lr_schedulers.step()
 
         self.current_step += 1
-
-        # save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
-        # TODOs:
-        # 1. Figure out if there is a value in calling state_dict_adatpr.to_hf()
-        # 2. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
-        #    May need to replicate the same in this code path.
-        # 3. Integrate zero-overhead version of push_state_dict.
-        # 4. Figure out a way to notify the generator app that weights are ready. This beyond the initial integration success.
-        # 5. Unify CheckpointManager and TorchStore weights save control path.
-        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
-        # if self.current_step % self.train_config.val_every_n_steps == 0:
-        #     self.validate()
         self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,

From 70d98613789ac53b3fcf91a383c9d0d2d609ec5d Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Sun, 31 Aug 2025 16:17:43 -0700
Subject: [PATCH 4/4] with_gpus is a bool config

---
 apps/rl/llama3_8b.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml
index 06ba15882..ac4635575 100644
--- a/apps/rl/llama3_8b.yaml
+++ b/apps/rl/llama3_8b.yaml
@@ -18,7 +18,7 @@ trainer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
-    with_gpus: 4
+    with_gpus: True
     num_procs: 4
 
   optimizer:
@@ -68,7 +68,7 @@ replay_buffer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
-    with_gpus: 0
+    with_gpus: False
    num_procs: 1
 
 # policy:
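
For reference, the combined effect of PATCH 2/4 and PATCH 4/4 on
apps/rl/llama3_8b.yaml is that the per-actor num_gpus count is replaced by a
boolean with_gpus flag. Reconstructed from the hunks above (surrounding keys
omitted), the two process configs end up as:

    trainer:
      processes:
        scheduler: local # local | mast (not supported yet)
        num_hosts: 1
        with_gpus: True
        num_procs: 4

    replay_buffer:
      processes:
        scheduler: local # local | mast (not supported yet)
        num_hosts: 1
        with_gpus: False
        num_procs: 1

Presumably the actual GPU count per host is now implied by num_procs rather
than configured directly.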