From 6e0dcdb5e919434e8cdf9fa3c96bb9519124c272 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Fri, 29 Aug 2025 10:40:16 -0700
Subject: [PATCH 1/3] skeleton code of ts integration

---
 src/forge/actors/trainer.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..f2997dd96 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import asyncio
 import logging
 import math
 import os
@@ -12,7 +13,14 @@
 from dataclasses import dataclass, field, fields
 
 import torch
+import torchtitan.experiments.forge.train_spec as forge_train_spec
+
+# from tqdm import tqdm
+
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
+from torchstore._state_dict_utils import push_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -30,8 +38,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -185,6 +191,18 @@ def train_step(self, batch) -> None:
         self.engine.lr_schedulers.step()
         self.current_step += 1
 
+
+        # Save to TorchStore. Hacking into the Checkpointer's prepped state-dict for now.
+        # TODOs:
+        # 1. Figure out if there is value in calling state_dict_adapter.to_hf().
+        # 2. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
+        #    May need to replicate the same in this code path.
+        # 3. Integrate a zero-overhead version of push_state_dict.
+        # 4. Figure out a way to notify the generator app that weights are ready. This is beyond the scope of the initial integration.
+        # 5. Unify the CheckpointManager and TorchStore weight-save control paths.
+        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
+        # if self.current_step % self.train_config.val_every_n_steps == 0:
+        #     self.validate()
         self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,

From 15740b0f7e3cbb2c07fb990c88cc7f47d91028a8 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Mon, 1 Sep 2025 18:09:33 -0700
Subject: [PATCH 2/3] run instruction

---
 apps/vllm/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 2d3c81ad9..9da85e631 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 """To run:
-
+export HF_HUB_DISABLE_XET=1
 python -m apps.vllm.main --guided-decoding --num-samples 3
 """
 

From 326c1a3745e3e431cbab4ec9982e60abc113e2c3 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Mon, 1 Sep 2025 18:15:37 -0700
Subject: [PATCH 3/3] undo

---
 src/forge/actors/trainer.py | 22 ++--------------------
 1 file changed, 2 insertions(+), 20 deletions(-)

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index f2997dd96..4232ca5ca 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import asyncio
 import logging
 import math
 import os
@@ -13,14 +12,7 @@
 from dataclasses import dataclass, field, fields
 
 import torch
-import torchtitan.experiments.forge.train_spec as forge_train_spec
-
-# from tqdm import tqdm
-
-
-from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
-from torchstore._state_dict_utils import push_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -38,6 +30,8 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
+from forge.controller import ForgeActor
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -191,18 +185,6 @@ def train_step(self, batch) -> None:
         self.engine.lr_schedulers.step()
         self.current_step += 1
 
-
-        # Save to TorchStore. Hacking into the Checkpointer's prepped state-dict for now.
-        # TODOs:
-        # 1. Figure out if there is value in calling state_dict_adapter.to_hf().
-        # 2. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
-        #    May need to replicate the same in this code path.
-        # 3. Integrate a zero-overhead version of push_state_dict.
-        # 4. Figure out a way to notify the generator app that weights are ready. This is beyond the scope of the initial integration.
-        # 5. Unify the CheckpointManager and TorchStore weight-save control paths.
-        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
-        # if self.current_step % self.train_config.val_every_n_steps == 0:
-        #     self.validate()
         self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,
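
Context for the series: patch 1 pushes the trainer's prepped state dict into TorchStore under a per-step key, f"v{self.current_step}", and its TODO 4 leaves open how the generator learns that new weights are ready; patch 3 backs the change out again. The Python sketch below is a hypothetical illustration of that versioned handoff, not part of the patches: a plain dict stands in for the store, the v{step} key convention mirrors the push_state_dict call above, and the "latest" signal is an assumption rather than an existing TorchStore API.

# Hypothetical sketch of the versioned weights handoff (see TODO 4 above).
# A plain dict stands in for TorchStore; "latest" is an assumed signal,
# not a TorchStore API.
from typing import Any, Dict, Optional, Tuple

store: Dict[str, Any] = {}


def trainer_publish(step: int, state_dict: Dict[str, Any]) -> None:
    """Trainer side: publish weights under a per-step key, then advertise it."""
    store[f"v{step}"] = state_dict   # analogous to push_state_dict(store, sd, f"v{step}")
    store["latest"] = f"v{step}"     # assumed "weights are ready" notification


def generator_refresh(current: Optional[str]) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Generator side: fetch weights only when a newer version has been published."""
    latest = store.get("latest")
    if latest is None or latest == current:
        return None
    return latest, store[latest]


# Example: step 1 publishes, the generator picks it up once and then sees no change.
trainer_publish(1, {"layer.weight": [0.1, 0.2]})
print(generator_refresh(None))   # ('v1', {'layer.weight': [0.1, 0.2]})
print(generator_refresh("v1"))   # None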