From 9bc09d7af2e50e0b2bbd3b337229a3a95ec14b00 Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 1/6] Adding support for tracking optimizers states in Model Delta Tracker. (#3143) Summary: X-link: https://github.com/pytorch/torchrec/pull/3143 ### Overview This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode called `MOMENTUM_LAST` that enables tracking of momentum values from optimizers to support approximate top-k delta-row selection. ### Key Changes #### 1. Optimizer State Tracking Support * To support tracking of optimizer states, added an `optim_state_tracker_fn` attribute to the `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes, which are responsible for traversing the BatchedFused modules. * Implemented a `register_optim_state_tracker_fn()` method in both classes to register the tracker callable. * The tracker is invoked after each lookup operation. #### 2. Model Delta Tracker Changes * Added a `record_momentum()` method to track momentum values from optimizer states, and wired it into `record_lookup()`. * Added validation and optimizer-state tracker registration logic to support the new `MOMENTUM_LAST` mode. #### 3. 
New Tracking Mode * Added `TrackingMode.MOMENTUM_LAST` to [`**types.py**`](command:code-compose.open?%5B%22%2Ffbcode%2Ftorchrec%2Fdistributed%2Fmodel_tracker%2Ftypes.py%22%2Cnull%5D "/fbcode/torchrec/distributed/model_tracker/types.py") * Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values Differential Revision: D76868111 --- torchrec/distributed/embedding.py | 2 +- torchrec/distributed/embedding_lookup.py | 51 ++++++- torchrec/distributed/embedding_types.py | 6 +- torchrec/distributed/embeddingbag.py | 2 +- .../model_tracker/model_delta_tracker.py | 82 +++++++++- .../tests/test_model_delta_tracker.py | 144 ++++++++++++++++++ torchrec/distributed/model_tracker/types.py | 10 +- 7 files changed, 279 insertions(+), 18 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 673fbccae..49afcfe7f 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1587,7 +1587,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs) + self.post_lookup_tracker_fn(self, features, embs) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 891dd0b02..fa8338b7f 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -10,7 +10,7 @@ import logging from abc import ABC from collections import OrderedDict -from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -208,6 +208,10 @@ def __init__( ) self.grouped_configs = grouped_configs + # Model tracker function to tracker optimizer state + self.optim_state_tracker_fn: Optional[ + Callable[[nn.Module, 
KeyedJaggedTensor, torch.Tensor], None] + ] = None def _create_embedding_kernel( self, @@ -315,7 +319,13 @@ def forward( self._feature_splits, ) for emb_op, features in zip(self._emb_modules, features_by_group): - embeddings.append(emb_op(features).view(-1)) + lookup = emb_op(features).view(-1) + embeddings.append(lookup) + + # Model tracker optimizer state function, will only be set called + # when model tracker is configured to track optimizer state + if self.optim_state_tracker_fn is not None: + self.optim_state_tracker_fn(emb_op, features, lookup) return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor) @@ -420,6 +430,19 @@ def purge(self) -> None: # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. emb_module.purge() + def register_optim_state_tracker_fn( + self, + record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], + ) -> None: + """ + Model tracker function to tracker optimizer state + + Args: + record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. 
+ + """ + self.optim_state_tracker_fn = record_fn + class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]): """ @@ -519,6 +542,10 @@ def __init__( if scale_weight_gradients and get_gradient_division() else 1 ) + # Model tracker function to tracker optimizer state + self.optim_state_tracker_fn: Optional[ + Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None] + ] = None def _create_embedding_kernel( self, @@ -678,7 +705,12 @@ def forward( features._weights, self._scale_gradient_factor ) - embeddings.append(emb_op(features)) + lookup = emb_op(features) + embeddings.append(lookup) + # Model tracker optimizer state function, will only be set called + # when model tracker is configured to track optimizer state + if self.optim_state_tracker_fn is not None: + self.optim_state_tracker_fn(emb_op, features, lookup) if features.variable_stride_per_key() and len(self._emb_modules) > 1: stride_per_rank_per_key = list( @@ -811,6 +843,19 @@ def purge(self) -> None: # pyre-fixme[29]: `Union[Module, Tensor]` is not a function. emb_module.purge() + def register_optim_state_tracker_fn( + self, + record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], + ) -> None: + """ + Model tracker function to tracker optimizer state + + Args: + record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. 
+ + """ + self.optim_state_tracker_fn = record_fn + class MetaInferGroupedEmbeddingsLookup( BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 5434a3203..fbc17b408 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -391,7 +391,7 @@ def __init__( self._lookups: List[nn.Module] = [] self._output_dists: List[nn.Module] = [] self.post_lookup_tracker_fn: Optional[ - Callable[[KeyedJaggedTensor, torch.Tensor], None] + Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None] ] = None self.post_odist_tracker_fn: Optional[Callable[..., None]] = None @@ -444,14 +444,14 @@ def train(self, mode: bool = True): # pyre-ignore[3] def register_post_lookup_tracker_fn( self, - record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], ) -> None: """ Register a function to be called after lookup is done. This is used for tracking the lookup results and optimizer states. Args: - record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. + record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. 
""" if self.post_lookup_tracker_fn is not None: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 7c822051d..5b3b846e1 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1671,7 +1671,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs) + self.post_lookup_tracker_fn(self, features, embs) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 905bf7648..29c854175 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -13,7 +13,12 @@ import torch from torch import nn + from torchrec.distributed.embedding import ShardedEmbeddingCollection +from torchrec.distributed.embedding_lookup import ( + GroupedEmbeddingsLookup, + GroupedPooledEmbeddingsLookup, +) from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.model_tracker.delta_store import DeltaStore from torchrec.distributed.model_tracker.types import ( @@ -27,9 +32,16 @@ # Only IDs are tracked, no additional state is stored. TrackingMode.ID_ONLY: EmbdUpdateMode.NONE, # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that - # the earliest embedding values are stored since the last checkpoint or snapshot. - # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk(). + # the earliest embedding values are stored since the last checkpoint + # or snapshot. This mode is used for computing topk delta rows, which + # is currently achieved by running (new_emb - old_emb).norm().topk(). 
TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST, + # TrackingMode.MOMENTUM utilizes EmbdUpdateMode.LAST to ensure that + # the most recent momentum values—capturing the accumulated gradient + # direction and magnitude—are stored since the last batch. + # This mode supports approximate top-k delta-row selection, can be + # obtained by running momentum.norm().topk(). + TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST, } # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection. @@ -141,7 +153,9 @@ def trigger_compaction(self) -> None: # Update the current compact index to the end index to avoid duplicate compaction. self.curr_compact_index = end_idx - def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: + def record_lookup( + self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor + ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. @@ -152,6 +166,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode). Args: + emb_module (nn.Module): The embedding module in which the lookup was performed. kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record. states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt. """ @@ -162,7 +177,9 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch. elif self._mode == TrackingMode.EMBEDDING: self.record_embeddings(kjt, states) - + # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch. 
+ elif self._mode == TrackingMode.MOMENTUM_LAST: + self.record_momentum(emb_module, kjt) else: raise NotImplementedError(f"Tracking mode {self._mode} is not supported") @@ -228,6 +245,39 @@ def record_embeddings( states=torch.cat(per_table_emb[table_fqn]), ) + def record_momentum( + self, + emb_module: nn.Module, + kjt: KeyedJaggedTensor, + ) -> None: + # FIXME: this is the momentum from last iteration, use momentum from current iter + # for correctness. + # pyre-ignore Undefined attribute [16]: + momentum = emb_module._emb_module.momentum1_dev + # FIXME: support multiple tables per group, information can be extracted from + # module._config (i.e., GroupedEmbeddingConfig) + # pyre-ignore Undefined attribute [16]: + states = momentum.view(-1, emb_module._config.embedding_dims()[0])[ + kjt.values() + ].norm(dim=1) + + offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum( + torch.tensor(kjt.length_per_key(), dtype=torch.int64) + ) + assert ( + kjt.values().numel() == states.numel() + ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()} " + + for i, key in enumerate(kjt.keys()): + fqn = self.feature_to_fqn[key] + per_key_states = states[offsets[i] : offsets[i + 1]] + self.store.append( + batch_idx=self.curr_batch_idx, + table_fqn=fqn, + ids=kjt[key].values(), + states=per_key_states, + ) + def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: """ Return a dictionary of hit local IDs for each sparse feature. 
Ids are @@ -380,13 +430,31 @@ def _clean_fqn_fn(self, fqn: str) -> str: def _validate_and_init_tracker_fns(self) -> None: "To validate the mode is supported for the given module" for module in self.tracked_modules.values(): + # EMBEDDING mode is only supported for ShardedEmbeddingCollection assert not ( isinstance(module, ShardedEmbeddingBagCollection) and self._mode == TrackingMode.EMBEDDING ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings." - # register post lookup function - # pyre-ignore[29] - module.register_post_lookup_tracker_fn(self.record_lookup) + + if ( + self._mode == TrackingMode.ID_ONLY + or self._mode == TrackingMode.EMBEDDING + ): + # register post lookup function + # pyre-ignore[29] + module.register_post_lookup_tracker_fn(self.record_lookup) + elif self._mode == TrackingMode.MOMENTUM_LAST: + # pyre-ignore[29]: + for lookup in module._lookups: + assert isinstance( + lookup, + (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup), + ) + lookup.register_optim_state_tracker_fn(self.record_lookup) + else: + raise NotImplementedError( + f"Tracking mode {self._mode} is not supported" + ) # register auto compaction function at odist if self._auto_compact: # pyre-ignore[29] diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index a92a9b286..c3f641b98 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -1463,6 +1463,101 @@ def test_multiple_consumers( output_params=output_params, ) + @parameterized.expand( + [ + ( + "EC_and_single_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + 
feature_names=["f1"], + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_LAST, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ( + "EBC_and_multiple_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingBagConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1", "f2"], + pooling=PoolingType.SUM, + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_LAST, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ] + ) + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs") + def test_duplication_with_momentum( + self, + _test_name: str, + test_params: ModelDeltaTrackerInputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_duplication_with_momentum, + world_size=self.world_size, + test_params=test_params, 
+ ) + def _test_fqn_to_feature_names( rank: int, @@ -1859,3 +1954,52 @@ def _test_multiple_consumer( and returned.allclose(expected_ids), f"{i=}, {table_fqn=}, mismatch {returned=} vs {expected_ids=}", ) + + +def _test_duplication_with_momentum( + rank: int, + world_size: int, + test_params: ModelDeltaTrackerInputTestParams, +) -> None: + """ + Test momentum tracking functionality in model delta tracker. + + Validates that the tracker correctly captures and stores momentum values from + optimizer states when using TrackingMode.MOMENTUM_LAST mode. + """ + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) as ctx: + dt_model, baseline_model = get_models( + rank=rank, + world_size=world_size, + ctx=ctx, + embedding_config_type=test_params.embedding_config_type, + tables=test_params.embedding_tables, + config=test_params.model_tracker_config, + ) + dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1) + baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1) + features_list = model_input_generator(test_params.model_inputs, rank) + dt = dt_model.get_model_tracker() + table_fqns = dt.fqn_to_feature_names().keys() + table_fqns_list = list(table_fqns) + for features in features_list: + tracked_out = dt_model(features) + baseline_out = baseline_model(features) + unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out)) + tracked_out.sum().backward() + baseline_out.sum().backward() + dt_model_opt.step() + baseline_opt.step() + + delta_rows = dt.get_delta() + for table_fqn in table_fqns_list: + ids = delta_rows[table_fqn].ids + states = none_throws(delta_rows[table_fqn].states) + + unittest.TestCase().assertTrue(states is not None) + unittest.TestCase().assertTrue(ids.numel() == states.numel()) + unittest.TestCase().assertTrue(bool((states != 0).all().item())) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py index 
cec95af91..a5a56514c 100644 --- a/torchrec/distributed/model_tracker/types.py +++ b/torchrec/distributed/model_tracker/types.py @@ -41,13 +41,17 @@ class TrackingMode(Enum): Tracking mode for ``ModelDeltaTracker``. Enums: - ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring. - EMBEDDING: Tracks both row IDs and their corresponding embedding values, - enabling precise top-k result calculations. However, this option comes with increased memory usage. + ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring. + EMBEDDING: Tracks both row IDs and their corresponding embedding values, + enabling precise top-k result calculations. However, this option comes + with increased memory usage. + MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode + supports approximate top-k delta-row selection. """ ID_ONLY = "id_only" EMBEDDING = "embedding" + MOMENTUM_LAST = "momentum_last" class EmbdUpdateMode(Enum): From 4820931053ade8df71a401809e9f4e7a58672b4a Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 2/6] Adding support for MOMENTUM_DIFF and ROWWISE_ADAGRAD optimizer states (#3144) Summary: X-link: https://github.com/pytorch/torchrec/pull/3144 This diff extends the Model Delta Tracker to support two new tracking modes: `MOMENTUM_DIFF` and `ROWWISE_ADAGRAD`, which enable tracking of rowwise optimizer states for more sophisticated gradient analysis. 
Differential Revision: D76918891 --- .../model_tracker/model_delta_tracker.py | 134 ++++++++++++-- .../tests/test_model_delta_tracker.py | 170 ++++++++++++++++++ torchrec/distributed/model_tracker/types.py | 4 + 3 files changed, 296 insertions(+), 12 deletions(-) diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 29c854175..74e0bb7c7 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -8,14 +8,20 @@ # pyre-strict import logging as logger from collections import Counter, OrderedDict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Tuple import torch +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn +from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding from torchrec.distributed.embedding import ShardedEmbeddingCollection from torchrec.distributed.embedding_lookup import ( + BatchedFusedEmbeddingBag, GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup, ) @@ -26,6 +32,8 @@ EmbdUpdateMode, TrackingMode, ) +from torchrec.distributed.utils import none_throws + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = { @@ -42,6 +50,14 @@ # This mode supports approximate top-k delta-row selection, can be # obtained by running momentum.norm().topk(). TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST, + # MOMENTUM_DIFF keeps a running sum of the square of the gradients per row. + # Within each publishing interval, we track the starting value of this running + # sum on all used rows and then do a lookup when ``get_delta`` is called to query + # the latest sum. 
Then we can compute the delta of the two values and return them + # together with the row ids. + TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST, + # The same as MOMENTUM_DIFF. Adding for backward compatibility. + TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST, } # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection. @@ -99,6 +115,7 @@ def __init__( # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection self.tracked_modules: Dict[str, nn.Module] = {} + self.table_to_fqn: Dict[str, str] = {} self.feature_to_fqn: Dict[str, str] = {} # Generate the mapping from FQN to feature names. self.fqn_to_feature_names() @@ -180,6 +197,11 @@ def record_lookup( # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch. elif self._mode == TrackingMode.MOMENTUM_LAST: self.record_momentum(emb_module, kjt) + elif ( + self._mode == TrackingMode.MOMENTUM_DIFF + or self._mode == TrackingMode.ROWWISE_ADAGRAD + ): + self.record_rowwise_optim_state(emb_module, kjt) else: raise NotImplementedError(f"Tracking mode {self._mode} is not supported") @@ -278,6 +300,60 @@ def record_momentum( states=per_key_states, ) + def record_rowwise_optim_state( + self, + emb_module: nn.Module, + kjt: KeyedJaggedTensor, + ) -> None: + opt_states: List[List[torch.Tensor]] = ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `split_optimizer_states`. 
+ emb_module._emb_module.split_optimizer_states() + ) + proxy: torch.Tensor = torch.cat([state[0] for state in opt_states]) + states = proxy[kjt.values()] + assert ( + kjt.values().numel() == states.numel() + ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()} " + offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum( + torch.tensor(kjt.length_per_key(), dtype=torch.int64) + ) + for i, key in enumerate(kjt.keys()): + fqn = self.feature_to_fqn[key] + per_key_states = states[offsets[i] : offsets[i + 1]] + self.store.append( + batch_idx=self.curr_batch_idx, + table_fqn=fqn, + ids=kjt[key].values(), + states=per_key_states, + ) + + def get_latest(self) -> Dict[str, torch.Tensor]: + ret: Dict[str, torch.Tensor] = {} + for module in self.tracked_modules.values(): + # pyre-fixme[29]: + for lookup in module._lookups: + for embs_module in lookup._emb_modules: + assert isinstance( + embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding) + ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found" + tbe = embs_module._emb_module + + assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen) + table_names = [t.name for t in embs_module._config.embedding_tables] + opt_states = tbe.split_optimizer_states() + assert len(table_names) == len(opt_states) + + for i, table_name in enumerate(table_names): + emb_fqn = self.table_to_fqn[table_name] + table_state = opt_states[i][0] + assert ( + emb_fqn not in ret + ), f"a table with {emb_fqn} already exists" + ret[emb_fqn] = table_state + + return ret + def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: """ Return a dictionary of hit local IDs for each sparse feature. 
Ids are @@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso per_table_delta_rows = self.get_delta(consumer) return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()} - def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: + def get_delta( + self, + consumer: Optional[str] = None, + top_percentage: Optional[float] = 1.0, + per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None, + sorted_by_indices: Optional[bool] = True, + ) -> Dict[str, DeltaRows]: """ Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The Values are first keyed by submodule FQN. @@ -314,6 +396,17 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: self.per_consumer_batch_idx[consumer] = index_end if self._delete_on_read: self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values())) + + if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD): + square_sum_map = self.get_latest() + for fqn, rows in tracker_rows.items(): + assert ( + fqn in square_sum_map + ), f"{fqn} not found in {square_sum_map.keys()}" + # pyre-fixme[58]: `-` is not supported for operand types `Tensor` + # and `Optional[Tensor]`. 
+ rows.states = square_sum_map[fqn][rows.ids] - rows.states + return tracker_rows def get_tracked_modules(self) -> Dict[str, nn.Module]: @@ -330,7 +423,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: return self._fqn_to_feature_map table_to_feature_names: Dict[str, List[str]] = OrderedDict() - table_to_fqn: Dict[str, str] = OrderedDict() for fqn, named_module in self._model.named_modules(): split_fqn = fqn.split(".") # Skipping partial FQNs present in fqns_to_skip @@ -356,13 +448,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: # will incorrectly match fqn with all the table names that have the same prefix if table_name in split_fqn: embedding_fqn = self._clean_fqn_fn(fqn) - if table_name in table_to_fqn: + if table_name in self.table_to_fqn: # Sanity check for validating that we don't have more then one table mapping to same fqn. logger.warning( - f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" ) - table_to_fqn[table_name] = embedding_fqn - logger.info(f"Table to fqn: {table_to_fqn}") + self.table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {self.table_to_fqn}") flatten_names = [ name for names in table_to_feature_names.values() for name in names ] @@ -375,15 +467,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() for table_name in table_to_feature_names: - if table_name not in table_to_fqn: + if table_name not in self.table_to_fqn: # This is likely unexpected, where we can't locate the FQN associated with this table. 
logger.warning( - f"Table {table_name} not found in {table_to_fqn}, skipping" + f"Table {table_name} not found in {self.table_to_fqn}, skipping" ) continue - fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[ - table_name - ] + fqn_to_feature_names[self.table_to_fqn[table_name]] = ( + table_to_feature_names[table_name] + ) self._fqn_to_feature_map = fqn_to_feature_names return fqn_to_feature_names @@ -451,6 +543,24 @@ def _validate_and_init_tracker_fns(self) -> None: (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup), ) lookup.register_optim_state_tracker_fn(self.record_lookup) + elif ( + self._mode == TrackingMode.ROWWISE_ADAGRAD + or self._mode == TrackingMode.MOMENTUM_DIFF + ): + # pyre-ignore[29]: + for lookup in module._lookups: + assert isinstance( + lookup, + (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup), + ) and all( + # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD + # pyre-ignore[16]: + emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD + # pyre-ignore[16]: + or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM + for emb in lookup._emb_modules + ) + lookup.register_optim_state_tracker_fn(self.record_lookup) else: raise NotImplementedError( f"Tracking mode {self._mode} is not supported" diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index c3f641b98..9362357fe 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -13,6 +13,9 @@ import torch import torchrec from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from parameterized import parameterized from torch import nn @@ -1558,6 +1561,101 @@ def test_duplication_with_momentum( test_params=test_params, ) + 
@parameterized.expand( + [ + ( + "EC_and_single_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1"], + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_DIFF, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ( + "EBC_and_multiple_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingBagConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1", "f2"], + pooling=PoolingType.SUM, + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.ROWWISE_ADAGRAD, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ] + ) + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + 
@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs") + def test_duplication_with_rowwise_adagrad( + self, + _test_name: str, + test_params: ModelDeltaTrackerInputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_duplication_with_rowwise_adagrad, + world_size=self.world_size, + test_params=test_params, + ) + def _test_fqn_to_feature_names( rank: int, @@ -2003,3 +2101,75 @@ def _test_duplication_with_momentum( unittest.TestCase().assertTrue(states is not None) unittest.TestCase().assertTrue(ids.numel() == states.numel()) unittest.TestCase().assertTrue(bool((states != 0).all().item())) + + +def _test_duplication_with_rowwise_adagrad( + rank: int, + world_size: int, + test_params: ModelDeltaTrackerInputTestParams, +) -> None: + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl" if torch.cuda.is_available() else "gloo", + ) as ctx: + dt_model, baseline_model = get_models( + rank=rank, + world_size=world_size, + ctx=ctx, + embedding_config_type=test_params.embedding_config_type, + tables=test_params.embedding_tables, + config=test_params.model_tracker_config, + optimizer_type=OptimType.EXACT_ROWWISE_ADAGRAD, + ) + + # read momemtum directly from the table + tbe: SplitTableBatchedEmbeddingBagsCodegen = ( + ( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no + # attribute `ec`. 
+ dt_model._dmp_wrapped_module.module.ec._lookups[0] + ._emb_modules[0] + .emb_module + ) + if test_params.embedding_config_type == EmbeddingConfig + else ( + dt_model._dmp_wrapped_module.module.ebc._lookups[0] # pyre-ignore + ._emb_modules[0] + .emb_module + ) + ) + assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen) + start_momentums = tbe.split_optimizer_states()[0][0].detach().clone() + + dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1) + baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1) + features_list = model_input_generator(test_params.model_inputs, rank) + + dt = dt_model.get_model_tracker() + table_fqns = dt.fqn_to_feature_names().keys() + table_fqns_list = list(table_fqns) + + for features in features_list: + tracked_out = dt_model(features) + baseline_out = baseline_model(features) + unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out)) + tracked_out.sum().backward() + baseline_out.sum().backward() + + dt_model_opt.step() + baseline_opt.step() + + end_momentums = tbe.split_optimizer_states()[0][0].detach().clone() + + delta_rows = dt.get_delta() + table_fqn = table_fqns_list[0] + + ids = delta_rows[table_fqn].ids + tracked_momentum = none_throws(delta_rows[table_fqn].states) + unittest.TestCase().assertTrue(tracked_momentum is not None) + unittest.TestCase().assertTrue(ids.numel() == tracked_momentum.numel()) + unittest.TestCase().assertTrue(bool((tracked_momentum != 0).all().item())) + + expected_momentum = end_momentums[ids] - start_momentums[ids] + unittest.TestCase().assertTrue(tracked_momentum.allclose(expected_momentum)) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py index a5a56514c..43a1b9223 100644 --- a/torchrec/distributed/model_tracker/types.py +++ b/torchrec/distributed/model_tracker/types.py @@ -47,11 +47,15 @@ class TrackingMode(Enum): with increased memory usage. 
MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode supports approximate top-k delta-row selection. + MOMENTUM_DIFF: Tracks both row IDs and their corresponding momentum difference values. + ROWWISE_ADAGRAD: Tracks both row IDs and their corresponding rowwise adagrad states. """ ID_ONLY = "id_only" EMBEDDING = "embedding" MOMENTUM_LAST = "momentum_last" + MOMENTUM_DIFF = "momentum_diff" + ROWWISE_ADAGRAD = "rowwise_adagrad" class EmbdUpdateMode(Enum): From 988646d9a0f988460562992a9f40af59192212dd Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 3/6] Update DeltaStore to be Generic (#3468) Summary: Make DeltaStore generic to allow use-case-specific custom implementations internal General Context: We are in the process of transitioning to a unified DeltaTracker, and this is diff 1/n of the changes towards that transition. Specific Context: DeltaTracker utilizes Memstore to preserve and compact lookups extracted during embedding lookups. As part of transitioning to a common DeltaTracker, we are adding a generic DeltaStore. Memstore will extend the generic DeltaStore, allowing both MRS and OSS DeltaTrackers to be easily integrated into training frameworks.
Differential Revision: D80614364 --- .../distributed/model_tracker/delta_store.py | 87 +++++++++++++++++-- .../model_tracker/model_delta_tracker.py | 12 +-- .../model_tracker/tests/test_delta_store.py | 14 +-- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py index 34ea8c88f..c302eb803 100644 --- a/torchrec/distributed/model_tracker/delta_store.py +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -7,6 +7,7 @@ # pyre-strict +from abc import ABC, abstractmethod from bisect import bisect_left from typing import Dict, List, Optional @@ -67,34 +68,106 @@ def _compute_unique_rows( return DeltaRows(ids=unique_ids, states=unique_states) -class DeltaStore: +class DeltaStore(ABC): """ - DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across - various batches during training, designed to be used with TorchRecs ModelDeltaTracker. - It maintains a CUDA in-memory representation of requested ids and embeddings/states, + DeltaStore is an abstract base class that defines the interface for storing and managing + local delta (row) updates for embeddings/states across various batches during training. + + Implementations should maintain a representation of requested ids and embeddings/states, providing a way to compact and get delta updates for each embedding table. The class supports different embedding update modes (NONE, FIRST, LAST) to determine how to handle duplicate ids when compacting or retrieving embeddings. + """ + + @abstractmethod + def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: + pass + + @abstractmethod + def append( + self, + batch_idx: int, + fqn: str, + ids: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + """ + Append a batch of ids and states to the store for a specific table. 
+ + Args: + batch_idx: The batch index + table_fqn: The fully qualified name of the table + ids: The tensor of ids to append + states: Optional tensor of states to append + """ + pass + + @abstractmethod + def delete(self, up_to_idx: Optional[int] = None) -> None: + """ + Delete all idx from the store up to `up_to_idx` + + Args: + up_to_idx: Optional index up to which to delete lookups + """ + pass + @abstractmethod + def compact(self, start_idx: int, end_idx: int) -> None: + """ + Compact (ids, embeddings) in batch index range from start_idx to end_idx. + + Args: + start_idx: The starting batch index + end_idx: The ending batch index + """ + pass + + @abstractmethod + def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: + """ + Return all unique/delta ids per table from the Delta Store. + + Args: + from_idx: The batch index from which to get deltas + + Returns: + A dictionary mapping table FQNs to their delta rows + """ + pass + + +class DeltaStoreTrec(DeltaStore): + """ + DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages + local delta (row) updates for embeddings/states across various batches during training, + designed to be used with TorchRecs ModelDeltaTracker. + + It maintains a CUDA in-memory representation of requested ids and embeddings/states, + providing a way to compact and get delta updates for each embedding table. + + The class supports different embedding update modes (NONE, FIRST, LAST) to determine + how to handle duplicate ids when compacting or retrieving embeddings. 
""" def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: + super().__init__(embdUpdateMode) self.embdUpdateMode = embdUpdateMode self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} def append( self, batch_idx: int, - table_fqn: str, + fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], ) -> None: - table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, []) + table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) table_fqn_lookup.append( IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) ) - self.per_fqn_lookups[table_fqn] = table_fqn_lookup + self.per_fqn_lookups[fqn] = table_fqn_lookup def delete(self, up_to_idx: Optional[int] = None) -> None: """ diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 74e0bb7c7..d84fe31db 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -26,7 +26,7 @@ GroupedPooledEmbeddingsLookup, ) from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection -from torchrec.distributed.model_tracker.delta_store import DeltaStore +from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec from torchrec.distributed.model_tracker.types import ( DeltaRows, EmbdUpdateMode, @@ -122,7 +122,7 @@ def __init__( # Validate is the mode is supported for the given module and initialize tracker functions self._validate_and_init_tracker_fns() - self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode]) + self.store: DeltaStoreTrec = DeltaStoreTrec(UPDATE_MODE_MAP[self._mode]) # Mapping feature name to corresponding FQNs. This is used for retrieving # the FQN associated with a given feature name in record_lookup(). 
@@ -222,7 +222,7 @@ def record_ids(self, kjt: KeyedJaggedTensor) -> None: for table_fqn, ids_list in per_table_ids.items(): self.store.append( batch_idx=self.curr_batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=torch.cat(ids_list), states=None, ) @@ -262,7 +262,7 @@ def record_embeddings( for table_fqn, ids_list in per_table_ids.items(): self.store.append( batch_idx=self.curr_batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=torch.cat(ids_list), states=torch.cat(per_table_emb[table_fqn]), ) @@ -295,7 +295,7 @@ def record_momentum( per_key_states = states[offsets[i] : offsets[i + 1]] self.store.append( batch_idx=self.curr_batch_idx, - table_fqn=fqn, + fqn=fqn, ids=kjt[key].values(), states=per_key_states, ) @@ -323,7 +323,7 @@ def record_rowwise_optim_state( per_key_states = states[offsets[i] : offsets[i + 1]] self.store.append( batch_idx=self.curr_batch_idx, - table_fqn=fqn, + fqn=fqn, ids=kjt[key].values(), states=per_key_states, ) diff --git a/torchrec/distributed/model_tracker/tests/test_delta_store.py b/torchrec/distributed/model_tracker/tests/test_delta_store.py index 3401089fd..2684a12dc 100644 --- a/torchrec/distributed/model_tracker/tests/test_delta_store.py +++ b/torchrec/distributed/model_tracker/tests/test_delta_store.py @@ -15,7 +15,7 @@ from parameterized import parameterized from torchrec.distributed.model_tracker.delta_store import ( _compute_unique_rows, - DeltaStore, + DeltaStoreTrec, ) from torchrec.distributed.model_tracker.types import ( DeltaRows, @@ -24,7 +24,7 @@ ) -class DeltaStoreTest(unittest.TestCase): +class DeltaStoreTrecTest(unittest.TestCase): # pyre-fixme[2]: Parameter must be annotated. 
def __init__(self, methodName="runTest") -> None: super().__init__(methodName) @@ -188,12 +188,12 @@ class AppendDeleteTestParams: def test_append_and_delete( self, _test_name: str, test_params: AppendDeleteTestParams ) -> None: - delta_store = DeltaStore() + delta_store = DeltaStoreTrec() for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): for lookup in lookup_list: delta_store.append( batch_idx=lookup.batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=lookup.ids, states=lookup.states, ) @@ -783,15 +783,15 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None: """ Test the compact method of DeltaStore. """ - # Create a DeltaStore with the specified embdUpdateMode - delta_store = DeltaStore(embdUpdateMode=test_params.embdUpdateMode) + # Create a DeltaStoreTrec with the specified embdUpdateMode + delta_store = DeltaStoreTrec(embdUpdateMode=test_params.embdUpdateMode) # Populate the DeltaStore with the test lookups for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): for lookup in lookup_list: delta_store.append( batch_idx=lookup.batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=lookup.ids, states=lookup.states, ) From a771311298dff6513fd467bc54e3de18d2873d41 Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 4/6] Update DeltaStore API (#3469) Summary: internal General Context: We are in the process of transitioning to a unified DeltaTracker, and this is diff 2/n of the changes towards that transition. Specific Context: Update DeltaStore APIs to match Memstore APIs for backward compatibility.
Differential Revision: D80614586 --- torchrec/distributed/model_parallel.py | 4 ++-- torchrec/distributed/model_tracker/delta_store.py | 4 ++-- .../distributed/model_tracker/model_delta_tracker.py | 6 +++--- .../model_tracker/tests/test_delta_store.py | 4 ++-- .../model_tracker/tests/test_model_delta_tracker.py | 12 ++++++------ 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 7c9c45824..deddf4b0a 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -466,14 +466,14 @@ def get_model_tracker(self) -> ModelDeltaTracker: ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." return self.model_delta_tracker - def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: + def get_unique(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: """ Returns the delta rows for the given consumer. """ assert ( self.model_delta_tracker is not None ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." - return self.model_delta_tracker.get_delta(consumer) + return self.model_delta_tracker.get_unique(consumer) def sparse_grad_parameter_names( self, destination: Optional[List[str]] = None, prefix: str = "" diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py index c302eb803..404bd1d83 100644 --- a/torchrec/distributed/model_tracker/delta_store.py +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -125,7 +125,7 @@ def compact(self, start_idx: int, end_idx: int) -> None: pass @abstractmethod - def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: + def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]: """ Return all unique/delta ids per table from the Delta Store. 
@@ -224,7 +224,7 @@ def compact(self, start_idx: int, end_idx: int) -> None: ) self.per_fqn_lookups = new_per_fqn_lookups - def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: + def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]: r""" Return all unique/delta ids per table from the Delta Store. """ diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index d84fe31db..3f0a6541b 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -362,10 +362,10 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso Args: consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer. """ - per_table_delta_rows = self.get_delta(consumer) + per_table_delta_rows = self.get_unique(consumer) return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()} - def get_delta( + def get_unique( self, consumer: Optional[str] = None, top_percentage: Optional[float] = 1.0, @@ -390,7 +390,7 @@ def get_delta( # and index_start could be equal to index_end, in which case we should not compact again. 
if index_start < index_end: self.compact(index_start, index_end) - tracker_rows = self.store.get_delta( + tracker_rows = self.store.get_unique( from_idx=self.per_consumer_batch_idx[consumer] ) self.per_consumer_batch_idx[consumer] = index_end diff --git a/torchrec/distributed/model_tracker/tests/test_delta_store.py b/torchrec/distributed/model_tracker/tests/test_delta_store.py index 2684a12dc..a8dc40d50 100644 --- a/torchrec/distributed/model_tracker/tests/test_delta_store.py +++ b/torchrec/distributed/model_tracker/tests/test_delta_store.py @@ -806,8 +806,8 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None: delta_store.compact( start_idx=test_params.start_idx, end_idx=test_params.end_idx ) - # Verify the result using get_delta method - delta_result = delta_store.get_delta() + # Verify the result using get_unique method + delta_result = delta_store.get_unique() # compare all fqns in the result for table_fqn, delta_rows in test_params.expected_delta.items(): diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index 9362357fe..b9bab879a 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -447,7 +447,7 @@ def test_fqn_to_feature_names( ), ), ( - "get_delta", + "get_unique", ModelDeltaTrackerInputTestParams( embedding_config_type=EmbeddingConfig, embedding_tables=[ @@ -464,7 +464,7 @@ def test_fqn_to_feature_names( model_tracker_config=ModelTrackerConfig(), ), TrackerNotInitOutputTestParams( - dmp_tracker_atter="get_delta", + dmp_tracker_atter="get_unique", ), ), ] @@ -1843,7 +1843,7 @@ def _test_embedding_mode( tracked_out.sum().backward() baseline_out.sum().backward() - delta_rows = dt.get_delta() + delta_rows = dt.get_unique() table_fqns = dt.fqn_to_feature_names().keys() table_fqns_list = list(table_fqns) @@ -1964,7 +1964,7 @@ 
def _test_multiple_get( unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out)) tracked_out.sum().backward() baseline_out.sum().backward() - delta_rows = dt.get_delta() + delta_rows = dt.get_unique() # Verify that the current batch index is correct unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1) @@ -2093,7 +2093,7 @@ def _test_duplication_with_momentum( dt_model_opt.step() baseline_opt.step() - delta_rows = dt.get_delta() + delta_rows = dt.get_unique() for table_fqn in table_fqns_list: ids = delta_rows[table_fqn].ids states = none_throws(delta_rows[table_fqn].states) @@ -2162,7 +2162,7 @@ def _test_duplication_with_rowwise_adagrad( end_momentums = tbe.split_optimizer_states()[0][0].detach().clone() - delta_rows = dt.get_delta() + delta_rows = dt.get_unique() table_fqn = table_fqns_list[0] ids = delta_rows[table_fqn].ids From 32253e29883fa188f52039d972586d0e30636a6f Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 5/6] Update ModelDeltaTracker to be Generic (#3470) Summary: Make ModelDeltaTracker generic to allow use-case-specific custom implementations internal General Context: We are in the process of transitioning to a unified DeltaTracker, and this is diff 3/n of the changes towards that transition. Specific Context: DeltaTracker implements primitives to allow tracking of embedding ids and states to optimize checkpointing and embedding freshness. As part of transitioning to a common DeltaTracker, we are adding a generic ModelDeltaTracker. MRS DeltaTracker will extend the generic ModelDeltaTracker.
Differential Revision: D80614689 --- torchrec/distributed/model_parallel.py | 10 +-- .../model_tracker/model_delta_tracker.py | 72 +++++++++++++++++-- .../tests/test_model_delta_tracker.py | 4 +- 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index deddf4b0a..d3280a209 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -29,7 +29,7 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size -from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology @@ -293,7 +293,7 @@ def __init__( if init_data_parallel: self.init_data_parallel() - self.model_delta_tracker: Optional[ModelDeltaTracker] = ( + self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = ( self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module) if model_tracker_config is not None else None @@ -369,9 +369,9 @@ def _init_dmp(self, module: nn.Module) -> nn.Module: def _init_delta_tracker( self, model_tracker_config: ModelTrackerConfig, module: nn.Module - ) -> ModelDeltaTracker: + ) -> ModelDeltaTrackerTrec: # Init delta tracker if config is provided - return ModelDeltaTracker( + return ModelDeltaTrackerTrec( model=module, consumers=model_tracker_config.consumers, delete_on_read=model_tracker_config.delete_on_read, @@ -456,7 +456,7 @@ def init_parameters(module: nn.Module) -> None: module.apply(init_parameters) - def get_model_tracker(self) -> ModelDeltaTracker: + def get_model_tracker(self) -> ModelDeltaTrackerTrec: """ Returns the model tracker if it exists. 
""" diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 3f0a6541b..2d1e9af97 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -7,6 +7,7 @@ # pyre-strict import logging as logger +from abc import ABC, abstractmethod from collections import Counter, OrderedDict from typing import Dict, Iterable, List, Optional, Tuple @@ -64,10 +65,73 @@ SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection) -class ModelDeltaTracker: +class ModelDeltaTracker(ABC): r""" - ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support + Abstract base class for ModelDeltaTracker that provides a way to track and retrieve unique IDs for supported modules, + along with optional support for tracking corresponding embeddings or states. This is useful for identifying and + retrieving the latest delta or unique rows for a given model, which can help compute topk or to stream updated + embeddings from predictors to trainers during online training. + + """ + + DEFAULT_CONSUMER: str = "default" + + @abstractmethod + def record_lookup( + self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor + ) -> None: + """ + Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. + + Args: + emb_module (nn.Module): The embedding module in which the lookup was performed. + kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record. + states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt. + """ + pass + + @abstractmethod + def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: + """ + Return a dictionary of hit local IDs for each sparse feature. + + Args: + consumer (str, optional): The consumer to retrieve unique IDs for. 
+ """ + pass + + @abstractmethod + def get_unique( + self, + consumer: Optional[str] = None, + top_percentage: Optional[float] = 1.0, + per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None, + sorted_by_indices: Optional[bool] = True, + ) -> Dict[str, DeltaRows]: + """ + Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. + + Args: + consumer (str, optional): The consumer to retrieve delta values for. + """ + pass + + @abstractmethod + def clear(self, consumer: Optional[str] = None) -> None: + """ + Clear tracked IDs for a given consumer. + + Args: + consumer (str, optional): The consumer to clear IDs/States for. + """ + pass + + +class ModelDeltaTrackerTrec(ModelDeltaTracker): + r""" + + ModelDeltaTrackerTrec provides a way to track and retrieve unique IDs for supported modules, along with optional support for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during online training. Unique IDs or states can be retrieved by calling the get_delta() method. @@ -85,8 +149,6 @@ class ModelDeltaTracker: """ - DEFAULT_CONSUMER: str = "default" - def __init__( self, model: nn.Module, @@ -354,7 +416,7 @@ def get_latest(self) -> Dict[str, torch.Tensor]: return ret - def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: + def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]: """ Return a dictionary of hit local IDs for each sparse feature. Ids are first keyed by submodule FQN. 
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index b9bab879a..a46268a3b 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -1739,7 +1739,7 @@ def _test_id_mode( tracked_out.sum().backward() baseline_out.sum().backward() - delta_ids = dt.get_delta_ids() + delta_ids = dt.get_unique_ids() table_fqns = dt.fqn_to_feature_names().keys() @@ -2035,7 +2035,7 @@ def _test_multiple_consumer( unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out)) tracked_out.sum().backward() baseline_out.sum().backward() - delta_rows = dt.get_delta_ids(consumer=consumer) + delta_rows = dt.get_unique_ids(consumer=consumer) # Verify that the current batch index is correct unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1) From 94e1849a83d1cacd473e04b8afccfda6e5fa08ea Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Mon, 20 Oct 2025 21:09:04 -0700 Subject: [PATCH 6/6] Updating record_lookup function signature to accommodate future implementations (#3471) Summary: internal General Context: We are in the process of transitioning to a unified DeltaTracker, and this is diff 4/n of the changes towards that transition.
Specific Context: Update record_lookup function signature to accommodate MRS DeltaTracker implementation Differential Revision: D80614980 --- torchrec/distributed/embedding.py | 2 +- torchrec/distributed/embedding_lookup.py | 20 +++++++++++-------- torchrec/distributed/embedding_types.py | 8 +++++--- torchrec/distributed/embeddingbag.py | 2 +- .../model_tracker/model_delta_tracker.py | 14 +++++++++---- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 49afcfe7f..d8554edea 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1587,7 +1587,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(self, features, embs) + self.post_lookup_tracker_fn(features, embs, self) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index fa8338b7f..9f3ce69c7 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -210,7 +210,7 @@ def __init__( self.grouped_configs = grouped_configs # Model tracker function to tracker optimizer state self.optim_state_tracker_fn: Optional[ - Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None] + Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None] ] = None def _create_embedding_kernel( @@ -325,7 +325,7 @@ def forward( # Model tracker optimizer state function, will only be set called # when model tracker is configured to track optimizer state if self.optim_state_tracker_fn is not None: - self.optim_state_tracker_fn(emb_op, features, lookup) + self.optim_state_tracker_fn(features, lookup, emb_op) return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor) @@ -432,13 +432,15 @@ def purge(self) -> None: def 
register_optim_state_tracker_fn( self, - record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[ + [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None + ], ) -> None: """ Model tracker function to tracker optimizer state Args: - record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. + record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done. """ self.optim_state_tracker_fn = record_fn @@ -544,7 +546,7 @@ def __init__( ) # Model tracker function to tracker optimizer state self.optim_state_tracker_fn: Optional[ - Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None] + Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None] ] = None def _create_embedding_kernel( @@ -710,7 +712,7 @@ def forward( # Model tracker optimizer state function, will only be set called # when model tracker is configured to track optimizer state if self.optim_state_tracker_fn is not None: - self.optim_state_tracker_fn(emb_op, features, lookup) + self.optim_state_tracker_fn(features, lookup, emb_op) if features.variable_stride_per_key() and len(self._emb_modules) > 1: stride_per_rank_per_key = list( @@ -845,13 +847,15 @@ def purge(self) -> None: def register_optim_state_tracker_fn( self, - record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[ + [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None + ], ) -> None: """ Model tracker function to tracker optimizer state Args: - record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. + record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done. 
""" self.optim_state_tracker_fn = record_fn diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index fbc17b408..d0a5ef920 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -391,7 +391,7 @@ def __init__( self._lookups: List[nn.Module] = [] self._output_dists: List[nn.Module] = [] self.post_lookup_tracker_fn: Optional[ - Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None] + Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None] ] = None self.post_odist_tracker_fn: Optional[Callable[..., None]] = None @@ -444,14 +444,16 @@ def train(self, mode: bool = True): # pyre-ignore[3] def register_post_lookup_tracker_fn( self, - record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None], + record_fn: Callable[ + [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None + ], ) -> None: """ Register a function to be called after lookup is done. This is used for tracking the lookup results and optimizer states. Args: - record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done. + record_fn (Callable[[KeyedJaggedTensor, torch.Tensor,Optional[nn.Module]], None]): A custom record function to be called after lookup is done. 
""" if self.post_lookup_tracker_fn is not None: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 5b3b846e1..fd6117884 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1671,7 +1671,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(self, features, embs) + self.post_lookup_tracker_fn(features, embs, self) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 2d1e9af97..cbd871a04 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -79,7 +79,10 @@ class ModelDeltaTracker(ABC): @abstractmethod def record_lookup( - self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor + self, + kjt: KeyedJaggedTensor, + states: torch.Tensor, + emb_module: Optional[nn.Module] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. @@ -233,7 +236,10 @@ def trigger_compaction(self) -> None: self.curr_compact_index = end_idx def record_lookup( - self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor + self, + kjt: KeyedJaggedTensor, + states: torch.Tensor, + emb_module: Optional[nn.Module] = None, ) -> None: """ Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states. @@ -258,12 +264,12 @@ def record_lookup( self.record_embeddings(kjt, states) # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch. 
elif self._mode == TrackingMode.MOMENTUM_LAST: - self.record_momentum(emb_module, kjt) + self.record_momentum(none_throws(emb_module), kjt) elif ( self._mode == TrackingMode.MOMENTUM_DIFF or self._mode == TrackingMode.ROWWISE_ADAGRAD ): - self.record_rowwise_optim_state(emb_module, kjt) + self.record_rowwise_optim_state(none_throws(emb_module), kjt) else: raise NotImplementedError(f"Tracking mode {self._mode} is not supported")