Conversation

@aliafzal (Contributor) commented Jun 7, 2025

Summary:

This PR is an initial check-in that introduces ModelDeltaTracker.

ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a TorchRec model. It's particularly useful for:

  1. Identifying which embedding rows were accessed during model execution
  2. Retrieving the latest delta or unique rows for a model
  3. Computing top-k changed embeddings
  4. Supporting streaming updated embeddings between systems during online training

The tracker works with ShardedEmbeddingCollection and ShardedEmbeddingBagCollection modules and supports different tracking modes (support for optimizer modes will be added in follow-up diffs):

  • ID_ONLY: Only tracks which IDs were accessed
  • EMBEDDING: Tracks both IDs and their embedding values

Key features:

  • Multiple consumer support (each consumer can track its own state)
  • Configurable deletion policy for tracked data
  • Ability to retrieve delta information for specific consumers

This utility helps optimize training workflows by enabling systems to focus on the most recently changed embeddings rather than processing the entire embedding table.
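The multi-consumer and delete-on-read semantics described above can be illustrated with a minimal, self-contained sketch. This is a toy model, not the TorchRec API: the class and method names below are hypothetical stand-ins, and the real tracker operates on torch.Tensors rather than plain sets.

```python
from typing import Dict, Set


class ToyDeltaTracker:
    """Toy per-consumer ID tracking (ID_ONLY-style), for illustration only."""

    def __init__(self, delete_on_read: bool = True) -> None:
        self.delete_on_read = delete_on_read
        # consumer name -> IDs seen since that consumer last read its delta
        self._pending: Dict[str, Set[int]] = {}

    def register_consumer(self, name: str) -> None:
        self._pending.setdefault(name, set())

    def record_lookup(self, ids: Set[int]) -> None:
        # every registered consumer accumulates the IDs touched by this batch
        for pending in self._pending.values():
            pending |= ids

    def get_delta(self, consumer: str) -> Set[int]:
        delta = set(self._pending[consumer])
        if self.delete_on_read:
            # deletion policy: reading a consumer's delta resets its state
            self._pending[consumer].clear()
        return delta
```

Because each consumer keeps its own pending set, two consumers can read deltas at different cadences: a consumer that reads after every batch sees only that batch's rows, while a slower consumer sees the union of all rows touched since its last read.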

Reviewed By: chouxi

Differential Revision: D75853147

aliafzal added 2 commits June 7, 2025 11:56
Summary:

Introducing the DeltaStore class, which efficiently manages embedding table updates with the following features:
*   Tracks embedding table updates by table FQN with batch indexing
*   Supports multiple embedding update modes (NONE, FIRST, LAST)
*   Provides compaction functionality for de-duplicating tracked lookups into unique rows
*   Allows retrieval of unique/delta IDs per table with optional embedding values

## How are lookups preserved and fetched?
In DeltaStore, lookups are preserved in the `per_fqn_lookups` dictionary, which maps table FQNs to lists of `IndexedLookup` objects. Each `IndexedLookup` contains:

1.  `idx`: The batch index
2.  `ids`: Tensor of embedding IDs
3.  `embeddings`: Optional tensor of embedding values

Lookups are added via the `append` method and can be:

*   Deleted with the `delete` method (up to a specific index)
*   Compacted with the `compact` method (merges lookups within a range)
*   Retrieved as unique/delta rows with the `get_delta` method
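The lifecycle above can be sketched in plain Python. This is a minimal toy model of the classes this diff adds: the names mirror the description (`IndexedLookup`, `per_fqn_lookups`, `append`, `delete`, `compact`, `get_delta`), but the real implementation operates on torch.Tensors and differs in detail; the LAST-wins merge shown here is just one of the update modes mentioned above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class IndexedLookup:
    idx: int                                        # batch index
    ids: List[int]                                  # embedding-row IDs
    embeddings: Optional[List[List[float]]] = None  # optional embedding values


class DeltaStore:
    def __init__(self) -> None:
        # table FQN -> ordered list of lookups
        self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}

    def append(self, fqn: str, lookup: IndexedLookup) -> None:
        self.per_fqn_lookups.setdefault(fqn, []).append(lookup)

    def delete(self, up_to_idx: int) -> None:
        # drop all lookups with batch index below up_to_idx
        for fqn, lookups in self.per_fqn_lookups.items():
            self.per_fqn_lookups[fqn] = [lk for lk in lookups if lk.idx >= up_to_idx]

    def compact(self, start_idx: int, end_idx: int) -> None:
        # merge lookups in [start_idx, end_idx) into one, keeping the
        # last value seen per ID (LAST mode)
        for fqn, lookups in self.per_fqn_lookups.items():
            in_range = [lk for lk in lookups if start_idx <= lk.idx < end_idx]
            if not in_range:
                continue
            rest = [lk for lk in lookups if not (start_idx <= lk.idx < end_idx)]
            merged: Dict[int, Optional[List[float]]] = {}
            for lk in in_range:
                embs = lk.embeddings or [None] * len(lk.ids)
                for row_id, emb in zip(lk.ids, embs):
                    merged[row_id] = emb
            compacted = IndexedLookup(
                idx=in_range[0].idx,
                ids=list(merged.keys()),
                embeddings=None if all(v is None for v in merged.values())
                else list(merged.values()),
            )
            self.per_fqn_lookups[fqn] = sorted(rest + [compacted], key=lambda lk: lk.idx)

    def get_delta(self, fqn: str) -> Dict[int, Optional[List[float]]]:
        # unique IDs (with latest embedding, if tracked) across stored lookups
        out: Dict[int, Optional[List[float]]] = {}
        for lk in self.per_fqn_lookups.get(fqn, []):
            embs = lk.embeddings or [None] * len(lk.ids)
            for row_id, emb in zip(lk.ids, embs):
                out[row_id] = emb
        return out
```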

## This diff:
1. delta_store.py includes the main logic to preserve, fetch, compact, and delete lookups
2. types.py includes the required datatypes and enums
3. test_delta_store.py includes test cases for the compute, delete, and compact methods

Reviewed By: TroyGarden

Differential Revision: D71130002
@facebook-github-bot added the "CLA Signed" label Jun 7, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@facebook-github-bot commented:
This pull request was exported from Phabricator. Differential Revision: D75853147

aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 9, 2025
Summary:
## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added a ModelTrackerConfig parameter to the DMP constructor
* When provided, it automatically initializes ModelDeltaTracker
* Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added a custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This tracks IDs/states natively in TorchRec without registering any nn.Module-specific hooks.
* Added a post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations such as incrementing the batch index

**Model Parallel API Enhancements**:
* Added a `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API gives the flexibility to integrate the model tracker into required components directly, without needing to access the dmp_module.
* Added a `get_delta()` method as a convenience API to retrieve delta rows from the dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks/callables
* Added hook registration methods in the embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)
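The hook flow described above can be illustrated with a small stand-in. The names here are hypothetical and the "lookup" is faked; in the real modules the hook receives torch.Tensors produced by the sharded lookup, and the tracker is the registered callable.

```python
from typing import Callable, List, Optional

# hook signature: receives the looked-up IDs and their embedding values
PostLookupHook = Callable[[List[int], List[List[float]]], None]


class ToyShardedModule:
    """Stand-in for a sharded embedding module with a post-lookup hook slot."""

    def __init__(self) -> None:
        self._post_lookup_hook: Optional[PostLookupHook] = None

    def register_post_lookup_hook(self, hook: PostLookupHook) -> None:
        self._post_lookup_hook = hook

    def forward(self, ids: List[int]) -> List[List[float]]:
        # fake lookup: each ID maps to a 2-dim embedding
        embeddings = [[float(i), float(i) * 2] for i in ids]
        if self._post_lookup_hook is not None:
            # the tracker captures IDs/embeddings here, without any
            # nn.Module-specific forward hooks
            self._post_lookup_hook(ids, embeddings)
        return embeddings
```

A tracker simply registers a callable on the module; the module invokes it on every forward pass after the lookup completes, so tracking adds no extra traversal of the embedding table.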
 

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a TorchRec model. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training


For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 9, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 9, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 9, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 9, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model using torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training
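The delta-tracking idea above (per-consumer unique IDs with an optional delete-on-read policy) can be illustrated with a toy, pure-Python sketch. `DeltaTrackerSketch` is hypothetical; the real ModelDeltaTracker operates on sharded embedding modules and tensors.

```python
from collections import OrderedDict
from typing import Dict, List


class DeltaTrackerSketch:
    """Toy tracker: per-consumer ordered sets of unique IDs seen since last read."""

    def __init__(self, delete_on_read: bool = True) -> None:
        self._delete_on_read = delete_on_read
        # consumer name -> IDs seen since that consumer's last read,
        # deduplicated while preserving first-seen order.
        self._per_consumer: Dict[str, "OrderedDict[int, None]"] = {}

    def register_consumer(self, name: str) -> None:
        self._per_consumer.setdefault(name, OrderedDict())

    def record(self, ids: List[int]) -> None:
        for consumer_ids in self._per_consumer.values():
            for i in ids:
                consumer_ids[i] = None

    def get_delta(self, consumer: str) -> List[int]:
        ids = list(self._per_consumer[consumer].keys())
        if self._delete_on_read:
            self._per_consumer[consumer].clear()
        return ids


t = DeltaTrackerSketch()
t.register_consumer("publisher")
t.record([5, 9, 5])
t.record([9, 2])
# t.get_delta("publisher") -> [5, 9, 2]; a second call returns [] (delete_on_read)
```

Because each consumer keeps its own state, a slow downstream system reading its delta does not affect what a second consumer sees.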

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 11, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 12, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 12, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 12, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 12, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable
* Added hook registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 13, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal pushed a commit to aliafzal/torchrec that referenced this pull request Jun 13, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 13, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 14, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker
*  Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules
* Implemented tracking support for different optimizer states (momentum, Adam states)

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 14, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 14, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal pushed a commit to aliafzal/torchrec that referenced this pull request Jun 16, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 16, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable
* Added callable registration methods in embedding modules

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 16, 2025
Summary:
Pull Request resolved: meta-pytorch#3064

## This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

### Key Components:
**ModelTrackerConfig Integration**:
* Added ModelTrackerConfig parameter to DMP constructor
* When provided, automatically initializes ModelDeltaTracker

**Custom Callables for Tracking**:
* Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks.
* Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist.
* Implemented pre_forward callables in DMP for operations like batch index incrementation

**Model Parallel API Enhancements**:
* Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.
* Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module.
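Per-consumer delta retrieval (each consumer tracking its own state, as described in the tracker context) can be sketched like this; `get_delta(consumer)` here is a hypothetical stand-in for the convenience API above:

```python
class MultiConsumerTracker:
    """Each consumer keeps its own cursor into the stream of tracked
    batches, so slow and fast consumers see independent deltas."""

    def __init__(self):
        self._batches = []  # list of per-batch accessed-ID lists
        self._cursor = {}   # consumer name -> next unread batch index

    def record_batch(self, ids):
        self._batches.append(list(ids))

    def get_delta(self, consumer):
        # Unique IDs across all batches this consumer has not yet read.
        start = self._cursor.get(consumer, 0)
        delta = sorted({i for batch in self._batches[start:] for i in batch})
        self._cursor[consumer] = len(self._batches)
        return delta

tracker = MultiConsumerTracker()
tracker.record_batch([1, 2])
print(tracker.get_delta("publisher"))   # [1, 2]
tracker.record_batch([2, 3])
print(tracker.get_delta("publisher"))   # [2, 3] -- only batches since last fetch
print(tracker.get_delta("checkpoint"))  # [1, 2, 3] -- this consumer saw everything
```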

**Embedding Module Changes**:
* Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callables
* Added callable registration methods in embedding modules

## ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in a model using TorchRec. It's particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming updated embeddings between systems during online training
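Use case 3 (computing top-k changed embeddings) amounts to ranking the tracked rows by how much their values moved; a pure-Python sketch under that assumption, with hypothetical names:

```python
import math

def top_k_changed(old_rows, new_rows, k):
    """Rank tracked rows by L2 distance between old and new embedding
    values and return the k most-changed row IDs."""
    def l2(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    deltas = {rid: l2(old_rows[rid], new_rows[rid]) for rid in new_rows}
    return sorted(deltas, key=deltas.get, reverse=True)[:k]

old = {0: [0.0, 0.0], 1: [1.0, 1.0], 2: [2.0, 2.0]}
new = {0: [0.1, 0.0], 1: [5.0, 5.0], 2: [2.0, 3.0]}
print(top_k_changed(old, new, 2))  # [1, 2]
```

This is why a delta tracker helps online training: a downstream system can publish only the most-changed rows instead of streaming the entire table.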

For more details see diff:D75853147 or PR meta-pytorch#3057

Differential Revision: D76202371
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 16, 2025
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 16, 2025
aliafzal pushed a commit to aliafzal/torchrec that referenced this pull request Jun 17, 2025
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 17, 2025
aliafzal pushed a commit to aliafzal/torchrec that referenced this pull request Jun 17, 2025
aliafzal added a commit to aliafzal/torchrec that referenced this pull request Jun 17, 2025
facebook-github-bot pushed a commit that referenced this pull request Jun 18, 2025
Reviewed By: TroyGarden

Differential Revision: D76202371

fbshipit-source-id: 5dc1c2459ee1821b246652c3edd6423695630023