Conversation

@casteryh (Contributor) commented Sep 27, 2025

Copilot generated summary follows:

This pull request introduces significant improvements to weight management in the GRPO training pipeline, focusing on more efficient handling, saving, and cleanup of model weights using torchstore and distributed checkpoints (DCP). The main changes include adding a mechanism to drop old weights, refactoring how weights are pushed and loaded (favoring DCP whole state dicts), and enhancing robustness with new utility functions and tests.

Weight Management Improvements

  • Added the drop_weights async function in apps/grpo/main.py to delete old model weights after each training step, preventing unnecessary storage growth. This function uses new utilities to locate and drop both the DCP handle and the individual parameter keys. ([1], [2])
  • Refactored the weight loading logic in PolicyWorker.update to prefer loading the entire state dict from a single DCP handle when available, falling back to individual parameters otherwise. This streamlines weight updates and reduces complexity. (src/forge/actors/policy.py, L563–R582)
  • Updated the weight saving logic in Trainer.push_weights to save the whole state dict as a single DCP handle when use_dcp is enabled, improving performance and consistency. (src/forge/actors/trainer.py, L347–R369) A sketch of this push / load / drop flow follows this list.
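
A minimal sketch of the push / load / drop flow described above. The `store` object stands in for an assumed async key/value interface (put / get / delete), and the key helpers, the DcpHandle shape, and the save/load wrappers are illustrative of this description rather than the actual torchstore or forge APIs:

```python
# Sketch only: `store`, the key helpers, and save_dcp_checkpoint /
# load_dcp_checkpoint are assumptions used for illustration; they are not
# the real torchstore / forge function names.

def dcp_key(version: int) -> str:
    return f"weights/v{version}/__whole_state_dict__"

def param_key(version: int, name: str) -> str:
    return f"weights/v{version}/{name}"

async def push_weights(store, state_dict, version: int, use_dcp: bool) -> None:
    # Trainer side: with use_dcp, save the whole state dict as one DCP
    # checkpoint and publish a single handle; otherwise push per-parameter keys.
    if use_dcp:
        handle = save_dcp_checkpoint(state_dict, version)  # assumed DCP wrapper
        await store.put(dcp_key(version), handle)
    else:
        for name, tensor in state_dict.items():
            await store.put(param_key(version, name), tensor)

async def update_weights(store, model, version: int, param_names) -> None:
    # Policy side: prefer the whole-state-dict DCP handle, fall back to
    # loading individual parameters.
    handle = await store.get(dcp_key(version), default=None)
    if handle is not None:
        model.load_weights(load_dcp_checkpoint(handle).items())  # assumed DCP wrapper
    else:
        for name in param_names:
            model.load_weights([(name, await store.get(param_key(version, name)))])

async def drop_weights(store, version: int, param_names) -> None:
    # Cleanup after a training step: drop the DCP handle (and its files) plus
    # any individual parameter keys for the old version.
    handle = await store.get(dcp_key(version), default=None)
    if handle is not None:
        handle.drop()  # deletes the checkpoint files; see the DcpHandle sketch below
        await store.delete(dcp_key(version))
    for name in param_names:
        await store.delete(param_key(version, name))
```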

Utility and Configuration Enhancements

  • Added new utility functions and improved the DcpHandle class in src/forge/actors/_torchstore_utils.py, including a robust drop() method that safely deletes checkpoints and handles manifold storage cases. ([1], [2]) A sketch of drop() follows this list.
  • Updated configuration files to enable built-in vLLM loading and DCP usage for both the policy and trainer components, aligning the pipeline with the new weight management strategy. ([1], [2])
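
A minimal sketch of the drop() behavior described above, assuming a DcpHandle dataclass that records the checkpoint location; the manifold:// prefix and the warning branch are assumptions, since the remote-storage details are not shown in this thread:

```python
import logging
import shutil
from dataclasses import dataclass
from typing import List, Optional

logger = logging.getLogger(__name__)

@dataclass
class DcpHandle:
    checkpoint_id: Optional[str] = None   # path (or manifold URI) of the saved DCP checkpoint
    param_names: Optional[List[str]] = None

    def drop(self) -> None:
        """Delete the checkpoint behind this handle; log failures rather than raising."""
        if self.checkpoint_id is None:
            return
        if self.checkpoint_id.startswith("manifold://"):
            # Remote (manifold) storage needs its own deletion call; the exact
            # client API is not shown here, so the sketch only logs it.
            logger.warning("Skipping manifold checkpoint cleanup: %s", self.checkpoint_id)
        else:
            try:
                shutil.rmtree(self.checkpoint_id, ignore_errors=False)
            except OSError as e:
                logger.error("Failed to delete checkpoint %s: %s", self.checkpoint_id, e)
        self.checkpoint_id = None
        self.param_names = None
```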

Testing and Reliability

These changes collectively make the training pipeline more efficient, reliable, and easier to maintain by improving how model weights are stored, loaded, and cleaned up.

@meta-cla bot added the CLA Signed label on Sep 27, 2025
@casteryh requested a review from Jack-Khuu on September 27, 2025 at 23:54
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.fanout(training_step)
await policy.update_weights.fanout(training_step)
Contributor commented:

I am wondering if there is a possibility of needing "some" history of the weights. Can the RL loop be still alive after this statement finishes but the policy model goes down?

casteryh (Contributor Author) replied:

> I am wondering if there is a possibility of needing "some" history of the weights. Can the RL loop be still alive after this statement finishes but the policy model goes down?

Can you elaborate on what you meant by "the RL loop be still alive"?

* metric logger simple example

* it works

* delete old files

* refactoring + docstrings

* docstring

* comments

* update method name

* no circular import

* update command

* update arg name

* move metric actor out of asyncio lock

* fix deregister

* lint

* docstring

* fix result extraction and add logger shutdown

* fix shutdown order

* simplification + docstrings

* bug fix + register if respawn

* it works

* use procmesh as key

* docstring

* remove protected imports

* create get_metric_logger

* call became fanout

* upstream changes

---------

Co-authored-by: Felipe Mello <felipemello@fb.com>
@joecummings (Member) left a comment:

Just one comment, but otherwise very helpful

return self._tokenizer.pad_token_id


async def drop_weights(version: int):
Member commented:

Great function, but this is actually the kind of logic I don't care to see in the main.py file. Would it be possible to have this be part of torchstore / dcp itself (or a wrapper we write)? That way we can specify here "keep_last_n_weights".
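
For illustration, a wrapper along those lines might look like the sketch below; keep_last_n_weights is a hypothetical name, and drop_weights refers to the function added in this PR:

```python
async def keep_last_n_weights(current_version: int, n: int = 2) -> None:
    """Keep the newest `n` weight versions and drop the one that just aged out."""
    stale_version = current_version - n
    if stale_version >= 0:
        await drop_weights(stale_version)
```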

casteryh (Contributor Author) replied:

fair

…model config (for Titan + vLLM) (#241)

* commit

* flag

* format

* nit

* nit
@Jack-Khuu (Contributor) commented Sep 29, 2025:

nit: Can you format the clickable links in the PR description? Copilot borked it

@casteryh (Contributor Author) replied:

> nit: Can you format the clickable links in the PR description? Copilot borked it

done!

@joecummings (Member) left a comment:

I left a few comments and questions, but overall LGTM

self.param_names = None
return

import shutil
Member commented:

You can import this at the top

import shutil

try:
    shutil.rmtree(self.checkpoint_id, ignore_errors=False)
Member commented:

Why do we want to suppress the errors here?

casteryh (Contributor Author) replied:

Fair. I was just thinking that logging the error is fine, and I don't want to crash everything if the delete is not successful. Let me know what you think.

loaded = model.load_weights([(name, param)])
del param
loaded_weights.update(loaded)
logger.info(
@joecummings (Member) commented Sep 29, 2025:

The entire weight update timing is already calculated at the policy update level.

casteryh (Contributor Author) replied:

Yes, but it's different from each worker's update time. The overall update time is basically the longest of the per-worker times. (A small sketch of the distinction follows.)
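
A small sketch of the distinction, with illustrative names rather than the actual actor code:

```python
import logging
import time

logger = logging.getLogger(__name__)

async def worker_update(version: int) -> None:
    # Per-worker timing: measured inside each worker's update call.
    start = time.perf_counter()
    ...  # load this worker's weights (per-parameter or via the DCP handle)
    logger.info("Worker loaded weights for v%d in %.2fs", version, time.perf_counter() - start)

# The top-level timing measured around policy.update_weights.fanout(...) is
# roughly the maximum of these per-worker times, since the slowest worker gates it.
```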

Member replied:

Eventually, we might just want to remove the top-level timing log, I suppose.

@joecummings (Member) commented:

We'll need #255 to land.

@casteryh merged commit a1714c3 into meta-pytorch:main on Sep 29, 2025
5 checks passed

Labels: CLA Signed (managed by the Meta Open Source bot) · Projects: none yet · 6 participants