24 changes: 24 additions & 0 deletions apps/grpo/main.py
@@ -8,6 +8,7 @@

import asyncio
import pprint
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
@@ -16,6 +17,10 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
@@ -239,6 +244,23 @@ async def pad_token(self):
return self._tokenizer.pad_token_id


async def drop_weights(version: int):
Member:
Great function, but this is actually the kind of logic I don't care to see in the main.py file. Would it be possible to have this be part of torchstore / dcp itself (or a wrapper we write)? That way we can specify here "keep_last_n_weights".

Contributor Author:
Fair.

print(f"Dropping weights @ version {version}")
start_time = time.perf_counter()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
# TODO: once we have something like `get_meta()` in torchstore, we can just
# query the type of the object instead of relying on keys.
dcp_key = get_dcp_whole_state_dict_key(version)
if dcp_key in matching_keys:
dcp_handle = await ts.get(dcp_key)
dcp_handle.drop()
for key in matching_keys:
await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")


async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = cfg.group_size
Expand Down Expand Up @@ -362,6 +384,8 @@ async def continuous_training():
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(training_step)
await policy.update_weights.fanout(training_step)
Contributor:
I am wondering if there is a possibility of needing "some" history of the weights. Can the RL loop still be alive after this statement finishes but the policy model goes down?

Contributor Author:
Can you elaborate on what you meant by "the RL loop still be alive"?

if training_step >= 2:
await drop_weights(training_step - 1)
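Relatedly, if some history of weights turns out to be useful (see the thread above), the hard-coded training_step - 1 could become a configurable window. A minimal sketch, assuming a hypothetical keep_last_n_weights entry in the config and that cfg is still in scope inside continuous_training (neither is part of this PR); with the default of 1 it reproduces the current behavior:

# Hypothetical generalization: keep the last N weight versions.
keep_last_n = cfg.get("keep_last_n_weights", 1)
stale_version = training_step - keep_last_n
if stale_version >= 1:
    await drop_weights(stale_version)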

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all services support it
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_1_7b.yaml
@@ -1,5 +1,5 @@
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.qwen3_1_7b --config apps/grpo/qwen3_1_7b.yaml
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

# Global configuration
group_size: 8
6 changes: 3 additions & 3 deletions apps/grpo/qwen3_8b.yaml
@@ -19,7 +19,7 @@ dataset:

# Policy configuration
policy:
use_vllm_builtin_load: false
use_vllm_builtin_load: true
engine_config:
model: ${model}
tensor_parallel_size: 2
@@ -33,8 +33,8 @@

# Trainer configuration
trainer:
vllm_tp_DEPRECATED: ${policy.engine_config.tensor_parallel_size}
use_vllm_builtin_load: false
use_dcp: true
use_vllm_builtin_load: true
model:
name: qwen3
flavor: 8B
91 changes: 91 additions & 0 deletions apps/toy_metrics/main.py
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio

import logging
import time

from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.observability.metric_actors import setup_metric_logger
from forge.observability.metrics import record_metric, ReductionType

from monarch.actor import current_rank, endpoint

logging.basicConfig(level=logging.DEBUG)


class TrainActor(ForgeActor):
    """Example training actor that records loss metrics."""

    @endpoint
    async def train_step(self, step: int):
        rank = current_rank().rank
        value = rank * 1000 + 100 * step
        print(f"[TRAIN] Rank {rank}: Step {step}, loss={value}")
        record_metric("train/loss", value)


class GeneratorActor(ForgeActor):
    """Example generation actor that records token count metrics."""

    @endpoint
    async def generate_step(self, step: int, substep: int):
        rank = current_rank().rank
        value = rank * 1000 + step * 100 + substep * 10
        print(f"[GEN] Rank {rank}: Step {step}.{substep}, tokens={value}")
        record_metric("generate/tokens", value, ReductionType.SUM)


# Main
async def main():
    """Example demonstrating distributed metric logging with different backends."""
    group = f"grpo_exp_{int(time.time())}"

    # Config format: {backend_name: backend_config_dict}
    # Each backend can specify reduce_across_ranks to control distributed logging behavior
    config = {
        "console": {"reduce_across_ranks": True},
        "wandb": {
            "project": "my_project",
            "group": group,
            "reduce_across_ranks": True,
            # Only useful if NOT reduce_across_ranks.
            "share_run_id": False,  # Share run ID across ranks -- Not recommended.
        },
    }

    service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
    mlogger = await setup_metric_logger()

    # Spawn services first (triggers registrations via provisioner hook)
    trainer = await TrainActor.options(**service_config).as_service()
    generator = await GeneratorActor.options(**service_config).as_service()

    # Now init config on global (inits backends eagerly across fetchers)
    await mlogger.init_backends.call_one(config)

    for i in range(3):
        print(f"\n=== Global Step {i} ===")
        await trainer.train_step.fanout(i)
        for sub in range(3):
            await generator.generate_step.fanout(i, sub)
        await mlogger.flush.call_one(i)

    # shutdown
    await mlogger.shutdown.call_one()

    await asyncio.gather(
        trainer.shutdown(),
        generator.shutdown(),
    )

    await shutdown()


if __name__ == "__main__":
    asyncio.run(main())
36 changes: 35 additions & 1 deletion src/forge/actors/_torchstore_utils.py
@@ -3,19 +3,49 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import shutil
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

KEY_DELIM = "."
DCP_WHOLE_STATE_TAG = "dcp_whole_state_dict"


@dataclass
class DcpHandle:
    checkpoint_id: str = ""
    checkpoint_id: str | None = None
    metadata: DcpMeta | None = None
    param_names: list[str] | None = None

    def drop(self) -> None:
        if self.checkpoint_id is None:
            raise ValueError("Dropping a null DcpHandle")
        if self.checkpoint_id.startswith("manifold://"):
            # Probably don't need to delete the checkpoint if it's on manifold
            logger.warning(
                f"Skipping deletion of {self.checkpoint_id} since it's on manifold"
            )
            self.checkpoint_id = None
            self.metadata = None
            self.param_names = None
            return

        try:
            shutil.rmtree(self.checkpoint_id, ignore_errors=False)
Member:
Why do we want to suppress the errors here?

Contributor Author:
Fair. I was just thinking logging the error is fine, and I don't want to crash everything if the delete is not successful. Let me know what you think.

logger.debug(f"Removed old weights at {self.checkpoint_id}")
except OSError as e:
logger.error(f"Error deleting {self.checkpoint_id}: {e}")
finally:
self.checkpoint_id = None
self.metadata = None
self.param_names = None


def load_tensor_from_dcp(handle: DcpHandle, param_name) -> torch.Tensor:
Expand All @@ -35,3 +65,7 @@ def get_param_key(policy_version: int, name: str) -> str:

def extract_param_name(key: str) -> str:
return KEY_DELIM.join(key.split(KEY_DELIM)[1:])


def get_dcp_whole_state_dict_key(policy_version: int) -> str:
return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
47 changes: 24 additions & 23 deletions src/forge/actors/policy.py
@@ -43,8 +43,8 @@
from vllm.worker.worker_base import WorkerWrapperBase

from forge.actors._torchstore_utils import (
DcpHandle,
extract_param_name,
get_dcp_whole_state_dict_key,
get_param_key,
get_param_prefix,
load_tensor_from_dcp,
@@ -481,8 +481,6 @@ class PolicyWorker(ForgeActor):
# TODO: remove this later since no plumbing exists to change this value.
# Also, whether to use dcp or not can be inferred from torchstore get() call.
use_dcp: bool = True
# Cache hf param names on first update call.
hf_param_names = []

# used for testing purposes only
_test_prev_params = {}
@@ -560,28 +558,31 @@ async def update(self, version: int):
logger.debug(f"{prefix=}")
matching_keys = await ts.keys(prefix)
logger.debug(f"{matching_keys=}")
if not self.hf_param_names:
self.hf_param_names = [extract_param_name(key) for key in matching_keys]
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
loaded_weights = set()
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
start = time.perf_counter()
for name in self.hf_param_names:
param_key = get_param_key(version, name)
tensor_or_handle = await ts.get(param_key)
if isinstance(tensor_or_handle, torch.Tensor):
param = tensor_or_handle
elif isinstance(tensor_or_handle, DcpHandle):
logger.info(f"Loading {name} from DCP with handle {tensor_or_handle}")
param = load_tensor_from_dcp(tensor_or_handle, name)
logger.info(f"Loaded {name} from DCP with handle {tensor_or_handle}")
else:
raise RuntimeError(
f"Unexpected type for {param_key}: {type(tensor_or_handle)}"
)
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
# Entire state dict is stored in a single DCP handle
if dcp_whole_state_dict_key in matching_keys:
logger.info(
f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}"
)
dcp_handle = await ts.get(dcp_whole_state_dict_key)
hf_param_names = dcp_handle.param_names
for name in hf_param_names:
param = load_tensor_from_dcp(dcp_handle, name)
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
else: # Load each parameter from torchstore directly without DCP
hf_param_names = [extract_param_name(key) for key in matching_keys]
# We can't pass a generator since vllm load_weights is not async.
# Instead, we just call load_weights with one parameter at a time.
for name in hf_param_names:
param_key = get_param_key(version, name)
param = await ts.get(param_key)
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
logger.info(
Member (@joecummings, Sep 29, 2025):
The entire weight update timing is already calculated at the policy update level.

Contributor Author:
Yes, but it's different from each worker's update time. The entire updating time is basically the longest among the workers.

Member (@joecummings):
Eventually, we might just want to remove the top-level logging time, I suppose.

f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds"
)
34 changes: 21 additions & 13 deletions src/forge/actors/trainer.py
@@ -36,7 +36,11 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors._torchstore_utils import DcpHandle, get_param_key
from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
@@ -328,11 +332,12 @@ async def _push_weights_DEPRECATED( # noqa: N802
@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
logger.info(f"Pushing weights for policy version {policy_version}")
if not self.use_vllm_builtin_load:
return await self._push_weights_DEPRECATED(
policy_version, self.vllm_tp_DEPRECATED
)

start_time = time.perf_counter()
if "model" not in self.engine.checkpointer.states:
raise RuntimeError("Model state not found in checkpointer state")

@@ -344,21 +349,24 @@ async def push_weights(self, policy_version: int) -> None:
)
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
if self.use_dcp:
# we could use dcp.save() to save the whole state dict,
# but I don't want too much deviation between the two code paths
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
dcp_id = f"{self.dcp_path}/{key}"
metadata = dcp.save(
checkpoint_id=dcp_id,
state_dict={name: param},
)
dcp_handle = DcpHandle(checkpoint_id=dcp_id, metadata=metadata)
await ts.put(key, dcp_handle)
key = get_dcp_whole_state_dict_key(policy_version)
dcp_id = f"{self.dcp_path}/{key}"
storage_writer = torch.distributed.checkpoint.FileSystemWriter(
dcp_id, single_file_per_rank=False, thread_count=8
)
metadata = dcp.save(storage_writer=storage_writer, state_dict=hf_state_dict)
dcp_handle = DcpHandle(
checkpoint_id=dcp_id,
metadata=metadata,
param_names=hf_state_dict.keys(),
)
await ts.put(key, dcp_handle)
else:
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
await ts.put(key, param)
end_time = time.perf_counter()
logger.info("Completed weights push in %.2f seconds", end_time - start_time)

@endpoint
async def cleanup(self) -> None:
8 changes: 1 addition & 7 deletions src/forge/controller/__init__.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .actor import ForgeActor
from .proc_mesh import get_proc_mesh, stop_proc_mesh

@@ -24,9 +23,4 @@ async def spawn_actors(
return actors


__all__ = [
"spawn_actors",
"stop_proc_mesh",
"get_proc_mesh",
"ForgeActor",
]
__all__ = ["spawn_actors", "stop_proc_mesh", "get_proc_mesh", "ForgeActor"]