fix dcp for new weight update #246
Conversation
```python
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(training_step)
await policy.update_weights.fanout(training_step)
```
I am wondering if there is a possibility of needing "some" history of the weights. Can the RL loop be still alive after this statement finishes but the policy model goes down?
Can you elaborate on what you meant by "the RL loop be still alive"?
* metric logger simple example
* it works
* delete old files
* refactoring + docstrings
* docstring
* comments
* update method name
* no circular import
* update command
* update arg name
* move metric actor out of asyncio lock
* fix deregister
* lint
* docstring
* fix result extraction and add logger shutdown
* fix shutdown order
* simplification + docstrings
* bug fix + register if respawn
* it works
* use procmesh as key
* docstring
* remove protected imports
* create get_metric_logger
* call became fanout
* upstream changes

---------

Co-authored-by: Felipe Mello <felipemello@fb.com>
Just one comment, but otherwise very helpful
```python
return self._tokenizer.pad_token_id


async def drop_weights(version: int):
```
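The diff above only shows the signature; going by the Copilot summary at the bottom of this page ("uses new utilities to locate and drop both DCP and individual parameter keys"), a rough sketch of its shape might look like the following. The `ts.*` calls and the key-helper functions are assumptions for illustration, not torchstore's verified API:

```python
import torchstore as ts  # assumed import; the ts.* API below is a guess

async def drop_weights(version: int):
    """Delete the DCP handle and any per-parameter keys for an old version."""
    prefix = get_param_prefix(version)               # hypothetical key helper
    matching_keys = await ts.keys(prefix)            # assumed torchstore call
    dcp_key = get_dcp_whole_state_dict_key(version)  # hypothetical key helper
    if dcp_key in matching_keys:
        dcp_handle = await ts.get(dcp_key)
        dcp_handle.drop()                            # DcpHandle.drop() from this PR
    for key in matching_keys:
        await ts.delete(key)                         # assumed torchstore call
```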
Great function, but this is actually the kind of logic I don't care to see in the main.py file. Would it be possible to have this be part of torchstore / dcp itself (or a wrapper we write)? That way we can specify here "keep_last_n_weights".
fair
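For illustration, a minimal sketch of the wrapper the review comment is describing. Only the `drop_weights` call comes from this PR; the class name, method, and `keep_last_n_weights` knob are hypothetical:

```python
# Hypothetical retention wrapper; only drop_weights() itself is from the PR.
class WeightRetention:
    def __init__(self, keep_last_n_weights: int):
        self.keep_last_n_weights = keep_last_n_weights
        self._versions: list[int] = []

    async def on_push(self, version: int) -> None:
        """Record a newly pushed weight version, then evict stale ones."""
        self._versions.append(version)
        while len(self._versions) > self.keep_last_n_weights:
            stale = self._versions.pop(0)
            await drop_weights(stale)  # the cleanup function added in this PR
```

The training loop would then call `await retention.on_push(training_step)` after each weight push instead of invoking `drop_weights` directly, keeping the retention policy out of `main.py`.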
…model config (for Titan + vLLM) (#241)

* commit
* flag
* format
* nit
* nit
nit: Can you format the clickable links in the PR description? Copilot borked it
done!
I left a few comments and questions, but overall LGTM
```python
self.param_names = None
return


import shutil
```
You can import this at the top
```python
import shutil

try:
    shutil.rmtree(self.checkpoint_id, ignore_errors=False)
```
Why do we want to suppress the errors here?
Fair. I was just thinking logging the error is fine, and we don't want to crash everything if the delete is not successful. Let me know what you think.
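A minimal sketch of the log-and-continue behavior being proposed here, assuming a standard module-level `logging` logger; the exact exception handling in the PR may differ:

```python
import logging
import shutil

logger = logging.getLogger(__name__)

def drop(self) -> None:
    """Delete the checkpoint directory without taking the job down on failure."""
    try:
        # ignore_errors=False so rmtree raises instead of silently skipping...
        shutil.rmtree(self.checkpoint_id, ignore_errors=False)
    except OSError as e:
        # ...and we log-and-continue here, so a failed delete is visible
        # but does not crash the training loop.
        logger.error("Failed to delete checkpoint %s: %s", self.checkpoint_id, e)
```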
```python
loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
logger.info(
```
The entire weight update timing is already calculated at the policy update level.
Yes, but it's different from each worker's update time. The overall update time is basically the longest among the workers.
Eventually, we might just want to remove the top-level logging time I suppose
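To make the distinction concrete, a sketch of the per-worker timing in question; `weights`, `model`, `logger`, and `rank` are assumed from the surrounding diff and context:

```python
import time

t0 = time.perf_counter()
loaded_weights = set()
for name, param in weights:  # this worker's share of the new weights
    loaded = model.load_weights([(name, param)])
    del param  # release the staged copy eagerly
    loaded_weights.update(loaded)
# Per-worker duration; the top-level update time reported at the policy
# level is roughly the maximum of these across all workers.
logger.info("Worker %s loaded %d weights in %.2fs",
            rank, len(loaded_weights), time.perf_counter() - t0)
```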
we'll need #255 to land
Copilot-generated summary follows:
This pull request introduces significant improvements to weight management in the GRPO training pipeline, focusing on more efficient handling, saving, and cleanup of model weights using torchstore and distributed checkpoints (DCP). The main changes include adding a mechanism to drop old weights, refactoring how weights are pushed and loaded (favoring DCP whole state dicts), and enhancing robustness with new utility functions and tests.
Weight Management Improvements

* Added a `drop_weights` async function in `apps/grpo/main.py` to efficiently delete old model weights after each training step, preventing unnecessary storage growth. This function uses new utilities to locate and drop both DCP and individual parameter keys. ([1], [2])
* Refactored `PolicyWorker.update` to prefer loading the entire state dict from a single DCP handle when available, falling back to individual parameters otherwise (sketched after this summary). This streamlines weight updates and reduces complexity. (src/forge/actors/policy.py L563-R582)
* Updated `Trainer.push_weights` to save the whole state dict as a single DCP handle when `use_dcp` is enabled, improving performance and consistency. (src/forge/actors/trainer.py L347-R369)

Utility and Configuration Enhancements
* Introduced a `DcpHandle` class in `src/forge/actors/_torchstore_utils.py`, including a robust `drop()` method to safely delete checkpoints and handle manifold storage cases. ([1], [2])

Testing and Reliability
* Added unit tests for the `DcpHandle.drop()` method to ensure proper deletion and cleanup behavior, including edge cases for manifold storage. (tests/unit_tests/test_torchstore_utils.py R1-R61)

These changes collectively make the training pipeline more efficient, reliable, and easier to maintain by improving how model weights are stored, loaded, and cleaned up.
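As a rough illustration of the load path described in the summary above (whole DCP state dict first, per-parameter fallback): every identifier below is an assumption for illustration, not the PR's exact code:

```python
# Hypothetical shape of PolicyWorker.update; `store`, the key helpers, and
# the DcpHandle interface are all assumed.
async def update(self, version: int) -> None:
    dcp_key = get_dcp_whole_state_dict_key(version)   # hypothetical helper
    if await store.exists(dcp_key):
        # Fast path: a single handle pointing at a saved DCP checkpoint.
        handle = await store.get(dcp_key)
        state_dict = handle.load()                    # assumed to read the checkpoint
        self.model.load_weights(state_dict.items())
        return
    # Fallback: fetch each parameter under its own key.
    for name in self.param_names:
        param = await store.get(get_param_key(version, name))  # hypothetical helper
        self.model.load_weights([(name, param)])
        del param
```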