
Conversation

casteryh (Contributor) commented on Oct 20, 2025

Summary

This PR enables shared memory multiprocess prefetch for DCP-based weight synchronization, resolving the TODO at generator.py:442.

Changes

Core functionality:

  • Added _fetch_weights_parallel() method - Unified parallel fetching logic for both DCP and non-DCP paths (see the sketch after this list)
  • Extended _WeightFetcher.fetch() method - Parameterized to handle both torchstore individual tensors and DCP checkpoints
  • Updated Generator.update_weights() - Automatically detects DCP usage and uses appropriate shared memory prefetch
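
For orientation, here is a rough sketch of the shape of the unified fetch path. The names `_fetch_weights_parallel`, `SharedTensor`, and `n_fetcher_procs` come from this PR; everything else (the process-pool fan-out, the callable signature) is illustrative rather than the actual implementation:

```python
# Illustrative sketch only; signatures and the use of a ProcessPoolExecutor are assumptions.
from concurrent.futures import ProcessPoolExecutor


def _fetch_weights_parallel(param_names, fetch_one, n_fetcher_procs=8):
    """Fan single-parameter fetches out across fetcher processes.

    `fetch_one` stands in for whatever a fetcher does for one parameter
    (a DCP load or a per-tensor ts.get); each result is a shared-memory handle.
    """
    handles = {}
    with ProcessPoolExecutor(max_workers=n_fetcher_procs) as pool:
        for name, handle in zip(param_names, pool.map(fetch_one, param_names)):
            handles[name] = handle  # one shared-memory copy per parameter, shared by all workers
    return handles
```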

Refactoring:

  • Reduced ~80 lines of code duplication by consolidating common parallelization and SharedTensor wrapping logic
  • Renamed checkpoint_id parameter to dcp_key for clarity (it's the torchstore key, not the filesystem checkpoint path)
  • Each fetcher now calls ts.get(dcp_key) to retrieve the DCP handle with metadata

Bug fixes:

  • Removed premature dcp_handle.drop() call - checkpoint files must be preserved for recovery

Performance Benefits

This change brings to DCP-based weight sync the same performance benefits that were previously available only for individual tensor loading:

  • Parallel loading across fetcher processes (8 by default, configurable via n_fetcher_procs)
  • Shared memory - a single host copy instead of one copy per worker
  • Reduced latency during weight updates
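
To illustrate the single-copy point (this is generic PyTorch behavior, not the PR's SharedTensor wrapper): a tensor whose storage has been moved to shared memory is mapped, not copied, by processes that receive it.

```python
# Generic PyTorch shared-memory behavior; the PR's SharedTensor wrapper may differ.
import torch

weights = torch.randn(4096, 4096)
weights.share_memory_()      # move the backing storage into shared memory (in place)
assert weights.is_shared()
# Worker processes that receive this tensor (e.g. via torch.multiprocessing) map the
# same buffer, so N workers read one copy instead of materializing N private copies.
```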

Technical Details

DCP path (sketched after the steps below):

  1. Generator calls ts.get(dcp_whole_state_dict_key) to get param names
  2. Passes the torchstore key to fetchers for parallel loading
  3. Each fetcher calls ts.get(dcp_key) to get the DCP handle (metadata + checkpoint path)
  4. Tensors loaded via load_tensor_from_dcp() and wrapped in SharedTensor
  5. Workers load from shared memory instead of each calling DCP load
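
A condensed sketch of that per-fetcher DCP branch. `ts.get`, `load_tensor_from_dcp`, and `SharedTensor` are the names used in this PR, but the call shapes, the `async` form, and the handle contents are assumptions:

```python
# Sketch of a fetcher's DCP branch; exact signatures are assumptions.
async def _fetch_from_dcp(param_names, dcp_key):
    dcp_handle = await ts.get(dcp_key)  # metadata + checkpoint path, not the tensors themselves
    shared = {}
    for name in param_names:
        tensor = load_tensor_from_dcp(dcp_handle, name)  # read one tensor from the checkpoint
        shared[name] = SharedTensor(tensor)              # copy it once into shared memory
    # Deliberately no dcp_handle.drop() here: the checkpoint files must survive for recovery.
    return shared
```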

Non-DCP path (sketched after the steps below):

  1. Generator queries torchstore keys to get param names
  2. Passes version number to fetchers for parallel loading
  3. Each fetcher loads individual tensors via ts.get(param_key)
  4. Tensors wrapped in SharedTensor and handles returned
  5. Workers load from shared memory instead of each calling ts.get()
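
And the corresponding non-DCP branch, again as a sketch; the key layout shown is a hypothetical example:

```python
# Sketch of a fetcher's non-DCP branch; the torchstore key format is hypothetical.
async def _fetch_individual_tensors(param_names, version):
    shared = {}
    for name in param_names:
        param_key = f"policy_ver_{version}.{name}"  # hypothetical key layout
        tensor = await ts.get(param_key)            # fetch one tensor from torchstore
        shared[name] = SharedTensor(tensor)         # wrap in shared memory for the workers
    return shared
```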

Testing

  • Syntax validation passed
  • All pre-commit hooks passed (trailing whitespace, AST check, flake8, ufmt, pydoclint)
  • No breaking changes to existing API

This PR was created by Claude Code on behalf of @casteryh

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 20, 2025
@casteryh marked this pull request as a draft on October 20, 2025 at 20:58
The DCP handle should not be dropped immediately after fetching weights to
shared memory. Dropping it will delete the checkpoint files on disk, which
we need to keep for potential recovery if something goes wrong. The checkpoint
cleanup should happen later when we're certain we don't need the checkpoint
for recovery.
- Renamed 'checkpoint_id' parameter to 'dcp_key' for clarity
  - The parameter is actually the torchstore key (e.g., 'policy_ver_X.dcp_whole_state_dict')
  - Not the actual checkpoint_id from the DCP handle itself
- Each fetcher now calls ts.get(dcp_key) to retrieve the DCP handle
  - This gives access to both metadata and the actual checkpoint path
  - More efficient than manually loading metadata in each fetcher
- Removed redundant metadata loading and DcpHandle construction code