provide basic integration of trec 2d emb to training pipeline #2929
Conversation
This pull request was exported from Phabricator. Differential Revision: D70988786
…ytorch#2929)

Summary:
* add `dmp_collection_sync_interval_batches` as a config value to SDD pipeline (and semi sync)
* default to 1 (every batch)
* disable using `None`
* disabled if DMPC not used

Differential Revision: D70988786
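For illustration, the interval semantics described in the summary (default 1 syncs every batch, a larger N syncs every N batches, `None` disables syncing) can be captured by a small predicate. This is a minimal sketch; the helper name `should_sync_dmp_collection` and its placement are assumptions, not part of the actual diff.

```python
from typing import Optional


def should_sync_dmp_collection(
    batch_index: int,
    dmp_collection_sync_interval_batches: Optional[int],
) -> bool:
    """Hypothetical predicate mirroring the config semantics in the summary."""
    # `None` disables DMP collection syncing entirely.
    if dmp_collection_sync_interval_batches is None:
        return False
    # 1 (the default) syncs on every batch; N > 1 syncs every N batches.
    return batch_index % dmp_collection_sync_interval_batches == 0
```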
4a38377 to fbb9472 (Compare)
fbb9472 to 7695cc8 (Compare)
7695cc8 to 4133ec5 (Compare)
Pull Request Overview
This PR adds basic support for DMP collection syncing to the training pipelines via a new configuration value. Key changes include adding helper functions (get_class_name and assert_instance), updating the sparse train pipeline to optionally sync the DMP collection every N batches, and adding tests to validate the new sync behavior.
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| torchrec/distributed/utils.py | Added helper functions get_class_name and assert_instance for improved type error messages. |
| torchrec/distributed/types.py | Introduced a stub for prefetch to support future extensions. |
| torchrec/distributed/train_pipeline/utils.py | Updated type checks and stream handling using assert_instance. |
| torchrec/distributed/train_pipeline/train_pipelines.py | Integrated a new sync_embeddings function that optionally syncs the DMP collection based on a configurable interval (see the sketch after this table). |
| torchrec/distributed/train_pipeline/tests/test_train_pipelines.py | Added tests to verify proper sync behavior and that sync is disabled when DMP is not used. |
| torchrec/distributed/embedding_types.py | Updated the prefetch method's parameter type to accept Multistreamable instances. |
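As a rough illustration of the train_pipelines.py change, an interval-gated sync hook inside a pipeline's progress step could look like the sketch below. The class, attribute, and method names (`_PipelineSketch`, `_dmp`, `sync`, `_sync_interval`) are assumptions for illustration and do not reproduce the actual sync_embeddings implementation.

```python
from typing import Optional


class _PipelineSketch:
    """Illustrative stand-in for a train pipeline with DMP-collection syncing."""

    def __init__(self, dmp: Optional[object], sync_interval: Optional[int]) -> None:
        # Hypothetical handle to the DMP collection; None when DMPC is not used.
        self._dmp = dmp
        # None disables syncing; 1 (the default) syncs every batch; N syncs every N batches.
        self._sync_interval = sync_interval
        self._batch_index = 0

    def _maybe_sync_embeddings(self) -> None:
        # Disabled explicitly via None, or implicitly when no DMP collection is present.
        if self._sync_interval is None or self._dmp is None:
            return
        if self._batch_index % self._sync_interval == 0:
            # In the real pipeline this is where the DMP collection's weight sync would run.
            self._dmp.sync()  # hypothetical sync entry point

    def progress(self) -> None:
        # ... forward / backward / optimizer steps would run here ...
        self._maybe_sync_embeddings()
        self._batch_index += 1
```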
```python
def assert_instance(obj: object, t: Type[_T]) -> _T:
    assert isinstance(obj, t), f"Got {get_class_name(obj)}"
```
[nitpick] Consider including the expected type in the assertion error message (e.g., 'Expected {t}, got {get_class_name(obj)}') to improve debuggability.
assert isinstance(obj, t), f"Got {get_class_name(obj)}" | |
assert isinstance(obj, t), f"Expected {t.__name__}, got {get_class_name(obj)}" |
f"{self.__class__.__name__} does not support context (not expected). " | ||
"Embedding weight sync is disabled." |
[nitpick] Consider clarifying the warning message to state when the context is absent, so it is clear under what conditions embedding sync is disabled.
f"{self.__class__.__name__} does not support context (not expected). " | |
"Embedding weight sync is disabled." | |
f"{self.__class__.__name__}: Embedding weight synchronization requires a valid " | |
"TrainPipelineContext. No context was provided, so embedding sync is disabled. " | |
"Ensure that a TrainPipelineContext is passed to enable this feature." |
Summary:
* add `dmp_collection_sync_interval_batches` as a config value to SDD pipeline (and semi sync)
* default to 1 (every batch)
* disable using `None`
* disabled if DMPC not used

Differential Revision: D70988786