fix dcp for new weight update #246
@@ -8,6 +8,7 @@
import asyncio
import pprint
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable

@@ -16,6 +17,10 @@
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
    get_dcp_whole_state_dict_key,
    get_param_prefix,
)
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer

@@ -239,6 +244,23 @@ async def pad_token(self):
        return self._tokenizer.pad_token_id

async def drop_weights(version: int):
    print(f"Dropping weights @ version {version}")
    start_time = time.perf_counter()
    prefix = get_param_prefix(version)
    matching_keys = await ts.keys(prefix)
    # TODO: once we have something like `get_meta()` in torchstore, we can just
    # query the type of the object instead of relying on keys.
    dcp_key = get_dcp_whole_state_dict_key(version)
    if dcp_key in matching_keys:
        dcp_handle = await ts.get(dcp_key)
        dcp_handle.drop()
    for key in matching_keys:
        await ts.delete(key)
    elapsed = time.perf_counter() - start_time
    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")

async def main(cfg: DictConfig):
    """Main GRPO training loop with rollout and training processes."""
    group_size = cfg.group_size

@@ -362,6 +384,8 @@ async def continuous_training():
        mlogger.log("loss/training_step", loss, training_step)
        await trainer.push_weights.fanout(training_step)
        await policy.update_weights.fanout(training_step)
Review comment: I am wondering if there is a possibility of needing "some" history of the weights. Can the RL loop still be alive after this statement finishes but the policy model goes down?

Reply: Can you elaborate on what you mean by the RL loop still being alive?
        if training_step >= 2:
            await drop_weights(training_step - 1)

    print("Starting GRPO training loops...")
    # TODO: Start multiple rollouts once all serivces support it

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio

import logging
import time

from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.observability.metric_actors import setup_metric_logger
from forge.observability.metrics import record_metric, ReductionType

from monarch.actor import current_rank, endpoint

logging.basicConfig(level=logging.DEBUG)


class TrainActor(ForgeActor):
    """Example training actor that records loss metrics."""

    @endpoint
    async def train_step(self, step: int):
        rank = current_rank().rank
        value = rank * 1000 + 100 * step
        print(f"[TRAIN] Rank {rank}: Step {step}, loss={value}")
        record_metric("train/loss", value)


class GeneratorActor(ForgeActor):
    """Example generation actor that records token count metrics."""

    @endpoint
    async def generate_step(self, step: int, substep: int):
        rank = current_rank().rank
        value = rank * 1000 + step * 100 + substep * 10
        print(f"[GEN] Rank {rank}: Step {step}.{substep}, tokens={value}")
        record_metric("generate/tokens", value, ReductionType.SUM)


# Main
async def main():
    """Example demonstrating distributed metric logging with different backends."""
    group = f"grpo_exp_{int(time.time())}"

    # Config format: {backend_name: backend_config_dict}
    # Each backend can specify reduce_across_ranks to control distributed logging behavior
    config = {
        "console": {"reduce_across_ranks": True},
        "wandb": {
            "project": "my_project",
            "group": group,
            "reduce_across_ranks": True,
            # Only useful if NOT reduce_across_ranks.
            "share_run_id": False,  # Share run ID across ranks -- Not recommended.
        },
    }

    service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
    mlogger = await setup_metric_logger()

    # Spawn services first (triggers registrations via provisioner hook)
    trainer = await TrainActor.options(**service_config).as_service()
    generator = await GeneratorActor.options(**service_config).as_service()

    # Now init config on global (inits backends eagerly across fetchers)
    await mlogger.init_backends.call_one(config)

    for i in range(3):
        print(f"\n=== Global Step {i} ===")
        await trainer.train_step.fanout(i)
        for sub in range(3):
            await generator.generate_step.fanout(i, sub)
        await mlogger.flush.call_one(i)

    # shutdown
    await mlogger.shutdown.call_one()

    await asyncio.gather(
        trainer.shutdown(),
        generator.shutdown(),
    )

    await shutdown()


if __name__ == "__main__":
    asyncio.run(main())
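As a side note on the config above, the same example can be pointed at per-rank logging by flipping `reduce_across_ranks`. The variant below is a sketch that only rearranges keys already shown in this file; whether this exact combination behaves well end to end is an assumption, not something this PR demonstrates.

```python
# Hypothetical variant of the example config: let each rank log its own stream
# instead of reducing across ranks before logging.
config = {
    "console": {"reduce_across_ranks": False},
    "wandb": {
        "project": "my_project",
        "group": group,
        "reduce_across_ranks": False,  # per-rank logging
        # Only consulted when reduce_across_ranks is False; the comment above
        # recommends keeping separate run IDs per rank.
        "share_run_id": False,
    },
}
```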
@@ -3,19 +3,49 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import shutil
from dataclasses import dataclass

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import Metadata as DcpMeta

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

KEY_DELIM = "."
DCP_WHOLE_STATE_TAG = "dcp_whole_state_dict"


@dataclass
class DcpHandle:
    checkpoint_id: str = ""
    checkpoint_id: str | None = None
    metadata: DcpMeta | None = None
    param_names: list[str] | None = None

    def drop(self) -> None:
        if self.checkpoint_id is None:
            raise ValueError("Dropping a null DcpHandle")
        if self.checkpoint_id.startswith("manifold://"):
            # Probably don't need to delete the checkpoint if it's on manifold
            logger.warning(
                f"Skipping deletion of {self.checkpoint_id} since it's on manifold"
            )
            self.checkpoint_id = None
            self.metadata = None
            self.param_names = None
            return

        try:
            shutil.rmtree(self.checkpoint_id, ignore_errors=False)
Review comment: Why do we want to suppress the errors here?

Reply: Fair. I was just thinking logging the error is fine, and we don't want to crash everything if the delete is not successful. Let me know what you think.
| logger.debug(f"Removed old weights at {self.checkpoint_id}") | ||
| except OSError as e: | ||
| logger.error(f"Error deleting {self.checkpoint_id}: {e}") | ||
| finally: | ||
| self.checkpoint_id = None | ||
| self.metadata = None | ||
| self.param_names = None | ||
|
|
||
|
|
||
| def load_tensor_from_dcp(handle: DcpHandle, param_name) -> torch.Tensor: | ||
|
|
@@ -35,3 +65,7 @@ def get_param_key(policy_version: int, name: str) -> str:

def extract_param_name(key: str) -> str:
    return KEY_DELIM.join(key.split(KEY_DELIM)[1:])


def get_dcp_whole_state_dict_key(policy_version: int) -> str:
    return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
@@ -43,8 +43,8 @@
from vllm.worker.worker_base import WorkerWrapperBase

from forge.actors._torchstore_utils import (
    DcpHandle,
    extract_param_name,
    get_dcp_whole_state_dict_key,
    get_param_key,
    get_param_prefix,
    load_tensor_from_dcp,

@@ -481,8 +481,6 @@ class PolicyWorker(ForgeActor):
    # TODO: remove this later since no plumbing exists to change this value.
    # Also, whether to use dcp or not can be inferred from torchstore get() call.
    use_dcp: bool = True
    # Cache hf param names on first update call.
    hf_param_names = []

    # used for tesing purposes only
    _test_prev_params = {}

@@ -560,28 +558,31 @@ async def update(self, version: int):
        logger.debug(f"{prefix=}")
        matching_keys = await ts.keys(prefix)
        logger.debug(f"{matching_keys=}")
        if not self.hf_param_names:
            self.hf_param_names = [extract_param_name(key) for key in matching_keys]
        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
        loaded_weights = set()
        # We can't pass a generator since vllm load_weights is not async.
        # Instead, we just call load_weights with one parameter at a time.
        start = time.perf_counter()
        for name in self.hf_param_names:
            param_key = get_param_key(version, name)
            tensor_or_handle = await ts.get(param_key)
            if isinstance(tensor_or_handle, torch.Tensor):
                param = tensor_or_handle
            elif isinstance(tensor_or_handle, DcpHandle):
                logger.info(f"Loading {name} from DCP with handle {tensor_or_handle}")
                param = load_tensor_from_dcp(tensor_or_handle, name)
                logger.info(f"Loaded {name} from DCP with handle {tensor_or_handle}")
            else:
                raise RuntimeError(
                    f"Unexpected type for {param_key}: {type(tensor_or_handle)}"
                )
            loaded = model.load_weights([(name, param)])
            del param
            loaded_weights.update(loaded)
        # Entire state dict is stored in a single DCP handle
        if dcp_whole_state_dict_key in matching_keys:
            logger.info(
                f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}"
            )
            dcp_handle = await ts.get(dcp_whole_state_dict_key)
            hf_param_names = dcp_handle.param_names
            for name in hf_param_names:
                param = load_tensor_from_dcp(dcp_handle, name)
                loaded = model.load_weights([(name, param)])
                del param
                loaded_weights.update(loaded)
        else:  # Load each parameter from torchstore directly without DCP
            hf_param_names = [extract_param_name(key) for key in matching_keys]
            # We can't pass a generator since vllm load_weights is not async.
            # Instead, we just call load_weights with one parameter at a time.
            for name in hf_param_names:
                param_key = get_param_key(version, name)
                param = await ts.get(param_key)
                loaded = model.load_weights([(name, param)])
                del param
                loaded_weights.update(loaded)
        logger.info(
Review comment: The entire weight update timing is already calculated at the policy update level.

Reply: Yes, but it's different from each worker's update time. The overall update time is basically the longest among the workers.

Reply: Eventually, we might just want to remove the top-level timing log, I suppose.
| f"[PolicyWorker::update] Updated {len(loaded_weights)} parameters, took {time.perf_counter() - start} seconds" | ||
| ) | ||
|
|
||
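For context on what this reader expects, here is a minimal sketch of the publisher side. The trainer's `push_weights` is not part of this diff, so everything below is hypothetical: it assumes a `ts.put(key, value)` torchstore API, a writable local checkpoint directory, and that `dcp.save` returns the checkpoint `Metadata` stored on the handle.

```python
import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.actors._torchstore_utils import (
    DcpHandle,
    get_dcp_whole_state_dict_key,
    get_param_key,
)


async def push_weights_sketch(
    version: int, state_dict: dict[str, torch.Tensor], use_dcp: bool = True
) -> None:
    """Hypothetical publisher matching the two load paths in update() above."""
    if use_dcp:
        # Whole state dict goes through DCP: save once, publish a single handle.
        checkpoint_id = f"/tmp/forge_weights/policy_v{version}"  # assumed path
        metadata = dcp.save(state_dict, checkpoint_id=checkpoint_id)
        handle = DcpHandle(
            checkpoint_id=checkpoint_id,
            metadata=metadata,
            param_names=list(state_dict.keys()),
        )
        await ts.put(get_dcp_whole_state_dict_key(version), handle)
    else:
        # No DCP: publish each parameter under its own versioned key.
        for name, tensor in state_dict.items():
            await ts.put(get_param_key(version, name), tensor)
```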
Review comment: Great function, but this is actually the kind of logic I don't care to see in the main.py file. Would it be possible to have this be part of torchstore / DCP itself (or a wrapper we write)? That way we can specify "keep_last_n_weights" here.

Reply: Fair.
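A minimal sketch of such a wrapper, reusing the `drop_weights` coroutine added in this PR; the name `prune_old_weights` and the `keep_last_n` parameter are hypothetical, and the sketch assumes the policy version advances by one per training step.

```python
async def prune_old_weights(current_version: int, keep_last_n: int = 1) -> None:
    """Keep the newest `keep_last_n` weight versions, dropping the one that aged out.

    Called once per training step, this keeps the retained window bounded. With
    keep_last_n=1 it reproduces the `if training_step >= 2: drop_weights(training_step - 1)`
    logic in the training loop above.
    """
    stale_version = current_version - keep_last_n
    if stale_version >= 1:
        await drop_weights(stale_version)


# e.g. in continuous_training(), after policy.update_weights:
# await prune_old_weights(training_step, keep_last_n=2)
```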