Summary:
Pull Request resolved: #3064
## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.
### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This makes ID and state tracking native to torchrec, without registering any nn.Module-specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable natively supports overlapping compaction with output distribution (odist).
* Implemented pre_forward callables in DMP for operations such as incrementing the batch index
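
The callable pattern described above can be illustrated with a minimal, torchrec-free sketch. Everything here is a hypothetical stand-in: `ToyShardedEmbedding`, `ToyModelParallel`, `ToyTracker`, and the registration method names are assumptions made for illustration and are not the classes or signatures added in this diff. It only shows the idea of a module invoking a registered tracker after lookup, and a wrapper running pre-forward callables before each batch.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical stand-ins for illustration; these are NOT torchrec classes.
Embedding = List[float]
TrackerFn = Callable[[List[int], List[Embedding]], None]


class ToyShardedEmbedding:
    """Toy lookup module that calls an optional tracker after each lookup."""

    def __init__(self, table: Dict[int, Embedding]) -> None:
        self._table = table
        self._post_lookup_tracker_fn: Optional[TrackerFn] = None

    def register_post_lookup_tracker_fn(self, fn: TrackerFn) -> None:
        # Assumed method name; the real registration API may differ.
        self._post_lookup_tracker_fn = fn

    def forward(self, ids: List[int]) -> List[Embedding]:
        embeddings = [self._table[i] for i in ids]
        # The tracker sees IDs and embeddings without any nn.Module hook.
        if self._post_lookup_tracker_fn is not None:
            self._post_lookup_tracker_fn(ids, embeddings)
        return embeddings


class ToyModelParallel:
    """Toy wrapper that runs pre-forward callables before each forward."""

    def __init__(self, module: ToyShardedEmbedding) -> None:
        self.module = module
        self._pre_forward_callables: List[Callable[[], None]] = []

    def register_pre_forward(self, fn: Callable[[], None]) -> None:
        self._pre_forward_callables.append(fn)

    def forward(self, ids: List[int]) -> List[Embedding]:
        for fn in self._pre_forward_callables:
            fn()
        return self.module.forward(ids)


class ToyTracker:
    """Records (batch_idx, id, embedding) for every looked-up row."""

    def __init__(self) -> None:
        self.batch_idx = 0
        self.records: List[Tuple[int, int, Embedding]] = []

    def increment_batch_idx(self) -> None:
        self.batch_idx += 1

    def record_lookup(self, ids: List[int], embeddings: List[Embedding]) -> None:
        for row_id, emb in zip(ids, embeddings):
            self.records.append((self.batch_idx, row_id, list(emb)))


table = {0: [0.0, 0.1], 1: [1.0, 1.1], 2: [2.0, 2.1]}
tracker = ToyTracker()
sharded = ToyShardedEmbedding(table)
sharded.register_post_lookup_tracker_fn(tracker.record_lookup)
dmp = ToyModelParallel(sharded)
dmp.register_pre_forward(tracker.increment_batch_idx)

dmp.forward([0, 2])  # batch 1
dmp.forward([1])     # batch 2
tracked = [(batch, row_id) for batch, row_id, _ in tracker.records]
```

After the two forward calls, `tracked` pairs each looked-up ID with the batch index in which it was seen.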
**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API gives callers the flexibility to integrate the model tracker into the components that need it without having to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.
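
As a rough sketch of how the two accessors relate, the toy below mirrors only the method names `get_model_tracker()` and `get_delta()` from this summary. The classes, the `record` helper, the return shape, the choice to clear state after each `get_delta()` call, and the error raised when tracking is disabled are all assumptions for illustration, not the behavior of the real DistributedModelParallel.

```python
from typing import Dict, List, Optional


class ToyDeltaTracker:
    """Hypothetical tracker that stores unique row IDs per table."""

    def __init__(self) -> None:
        self._rows: Dict[str, List[int]] = {}

    def record(self, table_name: str, ids: List[int]) -> None:
        seen = self._rows.setdefault(table_name, [])
        for row_id in ids:
            if row_id not in seen:
                seen.append(row_id)

    def get_delta(self) -> Dict[str, List[int]]:
        # Assumption: reading the delta resets the tracked state.
        delta = {name: sorted(ids) for name, ids in self._rows.items()}
        self._rows.clear()
        return delta


class ToyModelParallel:
    """Hypothetical wrapper; not the torchrec DistributedModelParallel."""

    def __init__(self, enable_tracking: bool) -> None:
        # Stands in for passing a tracker config to the constructor.
        self._tracker: Optional[ToyDeltaTracker] = (
            ToyDeltaTracker() if enable_tracking else None
        )

    def get_model_tracker(self) -> ToyDeltaTracker:
        if self._tracker is None:
            raise RuntimeError("model tracking was not enabled")
        return self._tracker

    def get_delta(self) -> Dict[str, List[int]]:
        # Convenience wrapper that forwards to the tracker.
        return self.get_model_tracker().get_delta()


dmp = ToyModelParallel(enable_tracking=True)
dmp.get_model_tracker().record("table_a", [3, 1, 3])
first_delta = dmp.get_delta()
second_delta = dmp.get_delta()
```

In this sketch a component can hold the tracker returned by `get_model_tracker()` and record into it directly, while a caller that only needs the rows uses `get_delta()` on the wrapper.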
**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callables
* Added callable registration methods in embedding modules
## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in models that use TorchRec. It's particularly useful for:
1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training
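
To make the first three use cases concrete, here is a small pure-Python sketch. The function names, the dictionary-based data layout, and the use of L2 distance as the "change" metric are assumptions for illustration only; the source does not state how ModelDeltaTracker ranks changed embeddings.

```python
import math
from typing import Dict, List, Tuple

Embedding = List[float]


def latest_unique_rows(lookups: List[Tuple[int, Embedding]]) -> Dict[int, Embedding]:
    """Deduplicate accessed rows, keeping the most recent embedding per ID."""
    latest: Dict[int, Embedding] = {}
    for row_id, emb in lookups:
        latest[row_id] = emb
    return latest


def top_k_changed(
    before: Dict[int, Embedding], after: Dict[int, Embedding], k: int
) -> List[Tuple[int, float]]:
    """Rank rows by L2 distance between old and new embeddings (assumed metric)."""
    changes: List[Tuple[int, float]] = []
    for row_id, new in after.items():
        old = before.get(row_id)
        if old is None:
            continue  # rows with no earlier snapshot are skipped in this sketch
        dist = math.sqrt(sum((n - o) ** 2 for n, o in zip(new, old)))
        changes.append((row_id, dist))
    # Largest change first; ties broken by row ID for determinism.
    changes.sort(key=lambda item: (-item[1], item[0]))
    return changes[:k]


before = {7: [0.0, 0.0], 9: [1.0, 1.0], 4: [2.0, 2.0]}
lookups = [
    (7, [3.0, 4.0]),
    (9, [1.0, 1.0]),
    (7, [0.0, 1.0]),  # row 7 accessed again; the later value wins
    (4, [2.0, 4.0]),
]
after = latest_unique_rows(lookups)
top = top_k_changed(before, after, k=2)
```

Here `after` holds one entry per accessed row, and `top` lists the two rows whose embeddings moved the most relative to `before`.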
For more details, see diff D75853147 or PR #3057.
Reviewed By: TroyGarden
Differential Revision: D76202371