Enable shared memory prefetch for DCP weight sync #468
Summary
This PR enables shared memory multiprocess prefetch for DCP-based weight synchronization, resolving the TODO at `generator.py:442`.

Changes
Core functionality:

- New `_fetch_weights_parallel()` method - unified parallel fetching logic for both the DCP and non-DCP paths
- `_WeightFetcher.fetch()` method - parameterized to handle both individual torchstore tensors and DCP checkpoints
- `Generator.update_weights()` - automatically detects DCP usage and uses the appropriate shared memory prefetch (see the sketch after this list)

Refactoring:

- Renamed the `checkpoint_id` parameter to `dcp_key` for clarity (it is the torchstore key, not the filesystem checkpoint path)
- Use `ts.get(dcp_key)` to retrieve the DCP handle together with its metadata

Bug fixes:
- Removed the `dcp_handle.drop()` call - checkpoint files must be preserved for recovery
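
A minimal sketch of the resulting dispatch in `Generator.update_weights()`. The method names come from this PR; the attribute, key layout, and downstream call are illustrative assumptions:

```python
async def update_weights(self, version: int) -> None:
    # Sketch only: `use_dcp_for_weight_sync` and the key layout are assumed.
    if self.use_dcp_for_weight_sync:
        # DCP path: a single torchstore key holds the checkpoint handle.
        fetched = await self._fetch_weights_parallel(
            dcp_key=f"model_weights/{version}"  # hypothetical key layout
        )
    else:
        # Non-DCP path: one torchstore key per parameter tensor.
        fetched = await self._fetch_weights_parallel(version=version)
    self._apply_fetched_weights(fetched)  # hypothetical downstream loader
```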
Performance Benefits

This change enables the same performance benefits for DCP-based weight sync that were previously only available for individual tensor loading:

- Weights are prefetched in parallel by multiple fetcher processes (`n_fetcher_procs`)
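
For illustration, a round-robin split of parameters across fetcher processes might look like the following; only `n_fetcher_procs` is taken from the PR, the helper itself is hypothetical:

```python
def partition_params(param_names: list[str], n_fetcher_procs: int) -> list[list[str]]:
    # Round-robin assignment so each fetcher process receives a
    # roughly equal share of the tensors to prefetch.
    return [param_names[i::n_fetcher_procs] for i in range(n_fetcher_procs)]

# partition_params(["a", "b", "c", "d", "e"], 2) -> [["a", "c", "e"], ["b", "d"]]
```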
Technical Details

DCP path:
- `ts.get(dcp_whole_state_dict_key)` to get the param names
- `ts.get(dcp_key)` to get the DCP handle (metadata + checkpoint path)
- Each tensor is loaded via `load_tensor_from_dcp()` and wrapped in a `SharedTensor` (see the sketch below)
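
A hedged sketch of the DCP fetch path inside a fetcher process. `ts.get`, `load_tensor_from_dcp`, and `SharedTensor` are named in this PR; the import, the async API, and the exact signatures are assumptions:

```python
import torchstore as ts  # assumed import path

# `load_tensor_from_dcp` and `SharedTensor` are this PR's helpers; their
# signatures below are guesses for illustration.
async def fetch_dcp_weights(dcp_key: str, dcp_whole_state_dict_key: str) -> dict:
    param_names = await ts.get(dcp_whole_state_dict_key)  # parameter names
    dcp_handle = await ts.get(dcp_key)  # handle: metadata + checkpoint path
    handles = {}
    for name in param_names:
        tensor = load_tensor_from_dcp(dcp_handle, name)  # assumed signature
        handles[name] = SharedTensor(tensor)             # copy into shared memory
    # Deliberately no dcp_handle.drop() here: the checkpoint files on disk
    # must be preserved so training can recover from them.
    return handles
```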
Non-DCP path:

- Each tensor is fetched individually via `ts.get(param_key)`
- Each tensor is wrapped in a `SharedTensor` and the handles are returned
- Same `ts.get()` access pattern as before (a sketch follows)
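
The corresponding non-DCP path, again as a sketch with assumed signatures:

```python
import torchstore as ts  # assumed import path

async def fetch_individual_weights(param_keys: list[str]) -> dict:
    handles = {}
    for key in param_keys:
        tensor = await ts.get(key)           # fetch one tensor from torchstore
        handles[key] = SharedTensor(tensor)  # wrap it in a shared memory handle
    return handles  # shared memory handles go back to the main process
```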
Testing

This PR was created by Claude Code on behalf of @casteryh