diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 673fbccae..49afcfe7f 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -1587,7 +1587,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 891dd0b02..fa8338b7f 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -208,6 +208,10 @@ def __init__(
         )
         self.grouped_configs = grouped_configs
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -315,7 +319,13 @@ def forward(
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
-            embeddings.append(emb_op(features).view(-1))
+            lookup = emb_op(features).view(-1)
+            embeddings.append(lookup)
+
+            # Model tracker optimizer state function; only called when the
+            # model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)
@@ -420,6 +430,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
     """
@@ -519,6 +542,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -678,7 +705,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )
 
-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function; only called when the
+            # model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)
 
         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -811,6 +843,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 5434a3203..fbc17b408 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -391,7 +391,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
        ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -444,14 +444,14 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.
 
         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
 
         """
         if self.post_lookup_tracker_fn is not None:
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 7c822051d..5b3b846e1 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)
 
             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,
diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 7c9c45824..deddf4b0a 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -466,14 +466,14 @@ def get_model_tracker(self) -> ModelDeltaTracker:
         ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
         return self.model_delta_tracker
 
-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_unique(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         """
         Returns the delta rows for the given consumer.
         """
         assert (
             self.model_delta_tracker is not None
         ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
-        return self.model_delta_tracker.get_delta(consumer)
+        return self.model_delta_tracker.get_unique(consumer)
 
     def sparse_grad_parameter_names(
         self, destination: Optional[List[str]] = None, prefix: str = ""
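The callbacks registered on a sharded module now receive the lookup module itself as a first argument, so a tracker can read optimizer state off the kernel that produced the embeddings. A minimal sketch of a callback conforming to the new `Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]` shape (the `MyTracker` class below is illustrative, not part of this patch):

```python
import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class MyTracker:
    """Hypothetical consumer of the three-argument tracker callback."""

    def record_lookup(
        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
    ) -> None:
        # emb_module is the lookup / grouped kernel that produced `states`,
        # which lets the tracker inspect that module's optimizer tensors.
        print(type(emb_module).__name__, kjt.keys(), states.shape)


# Both registration points touched by this patch accept the same callable shape:
#   sharded_module.register_post_lookup_tracker_fn(tracker.record_lookup)
#   grouped_lookup.register_optim_state_tracker_fn(tracker.record_lookup)
```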
diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py
index 34ea8c88f..404bd1d83 100644
--- a/torchrec/distributed/model_tracker/delta_store.py
+++ b/torchrec/distributed/model_tracker/delta_store.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+from abc import ABC, abstractmethod
 from bisect import bisect_left
 from typing import Dict, List, Optional
 
@@ -67,34 +68,106 @@ def _compute_unique_rows(
     return DeltaRows(ids=unique_ids, states=unique_states)
 
 
-class DeltaStore:
+class DeltaStore(ABC):
     """
-    DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
-    various batches during training, designed to be used with TorchRec's ModelDeltaTracker.
-    It maintains a CUDA in-memory representation of requested ids and embeddings/states,
+    DeltaStore is an abstract base class that defines the interface for storing and managing
+    local delta (row) updates for embeddings/states across various batches during training.
+
+    Implementations should maintain a representation of requested ids and embeddings/states,
     providing a way to compact and get delta updates for each embedding table.
 
     The class supports different embedding update modes (NONE, FIRST, LAST) to determine
     how to handle duplicate ids when compacting or retrieving embeddings.
+    """
+
+    @abstractmethod
+    def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
+        pass
+
+    @abstractmethod
+    def append(
+        self,
+        batch_idx: int,
+        fqn: str,
+        ids: torch.Tensor,
+        states: Optional[torch.Tensor],
+    ) -> None:
+        """
+        Append a batch of ids and states to the store for a specific table.
+
+        Args:
+            batch_idx: The batch index
+            fqn: The fully qualified name of the table
+            ids: The tensor of ids to append
+            states: Optional tensor of states to append
+        """
+        pass
+
+    @abstractmethod
+    def delete(self, up_to_idx: Optional[int] = None) -> None:
+        """
+        Delete all lookups from the store up to `up_to_idx`.
+
+        Args:
+            up_to_idx: Optional index up to which to delete lookups
+        """
+        pass
+
+    @abstractmethod
+    def compact(self, start_idx: int, end_idx: int) -> None:
+        """
+        Compact (ids, embeddings) in the batch index range from start_idx to end_idx.
+
+        Args:
+            start_idx: The starting batch index
+            end_idx: The ending batch index
+        """
+        pass
+
+    @abstractmethod
+    def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
+        """
+        Return all unique/delta ids per table from the DeltaStore.
+
+        Args:
+            from_idx: The batch index from which to get deltas
+        Returns:
+            A dictionary mapping table FQNs to their delta rows
+        """
+        pass
+
+
+class DeltaStoreTrec(DeltaStore):
+    """
+    DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages
+    local delta (row) updates for embeddings/states across various batches during training,
+    designed to be used with TorchRec's ModelDeltaTracker.
+
+    It maintains a CUDA in-memory representation of requested ids and embeddings/states,
+    providing a way to compact and get delta updates for each embedding table.
 
     The class supports different embedding update modes (NONE, FIRST, LAST) to determine
     how to handle duplicate ids when compacting or retrieving embeddings.
     """
 
     def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
+        super().__init__(embdUpdateMode)
         self.embdUpdateMode = embdUpdateMode
         self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
 
     def append(
         self,
         batch_idx: int,
-        table_fqn: str,
+        fqn: str,
         ids: torch.Tensor,
         states: Optional[torch.Tensor],
     ) -> None:
-        table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
+        table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
         table_fqn_lookup.append(
             IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
         )
-        self.per_fqn_lookups[table_fqn] = table_fqn_lookup
+        self.per_fqn_lookups[fqn] = table_fqn_lookup
 
     def delete(self, up_to_idx: Optional[int] = None) -> None:
         """
@@ -151,7 +224,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
             )
         self.per_fqn_lookups = new_per_fqn_lookups
 
-    def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
+    def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
         r"""
         Return all unique/delta ids per table from the Delta Store.
         """
diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py
index 905bf7648..3f0a6541b 100644
--- a/torchrec/distributed/model_tracker/model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -8,28 +8,56 @@
 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding
+
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
+from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
+    GroupedEmbeddingsLookup,
+    GroupedPooledEmbeddingsLookup,
+)
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
-from torchrec.distributed.model_tracker.delta_store import DeltaStore
+from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
 from torchrec.distributed.model_tracker.types import (
     DeltaRows,
     EmbdUpdateMode,
     TrackingMode,
 )
+from torchrec.distributed.utils import none_throws
+
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
     # Only IDs are tracked, no additional state is stored.
     TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
     # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
-    # the earliest embedding values are stored since the last checkpoint or snapshot.
-    # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
+    # the earliest embedding values are stored since the last checkpoint
+    # or snapshot. This mode is used for computing topk delta rows, which
+    # is currently achieved by running (new_emb - old_emb).norm().topk().
     TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
+    # TrackingMode.MOMENTUM_LAST utilizes EmbdUpdateMode.LAST to ensure that
+    # the most recent momentum values, capturing the accumulated gradient
+    # direction and magnitude, are stored since the last batch.
+    # This mode supports approximate top-k delta-row selection, which can be
+    # obtained by running momentum.norm().topk().
+    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the squared gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum on all used rows and then do a lookup when ``get_unique`` is called to
+    # query the latest sum. We can then compute the delta between the two values
+    # and return it together with the row ids.
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
 }
 
 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -87,13 +115,14 @@ def __init__(
 
         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
         # Validate is the mode is supported for the given module and initialize tracker functions
         self._validate_and_init_tracker_fns()
 
-        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
+        self.store: DeltaStoreTrec = DeltaStoreTrec(UPDATE_MODE_MAP[self._mode])
 
         # Mapping feature name to corresponding FQNs. This is used for retrieving
         # the FQN associated with a given feature name in record_lookup().
@@ -141,7 +170,9 @@ def trigger_compaction(self) -> None:
 
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx
 
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
 
@@ -152,6 +183,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
 
         Args:
+            emb_module (nn.Module): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
@@ -162,7 +194,14 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
@@ -183,7 +222,7 @@ def record_ids(self, kjt: KeyedJaggedTensor) -> None:
         for table_fqn, ids_list in per_table_ids.items():
             self.store.append(
                 batch_idx=self.curr_batch_idx,
-                table_fqn=table_fqn,
+                fqn=table_fqn,
                 ids=torch.cat(ids_list),
                 states=None,
             )
@@ -223,11 +262,98 @@ def record_embeddings(
         for table_fqn, ids_list in per_table_ids.items():
             self.store.append(
                 batch_idx=self.curr_batch_idx,
-                table_fqn=table_fqn,
+                fqn=table_fqn,
                 ids=torch.cat(ids_list),
                 states=torch.cat(per_table_emb[table_fqn]),
             )
 
+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from the last iteration; use the momentum
+        # from the current iteration for correctness.
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group; the information can be
+        # extracted from module._config (i.e., GroupedEmbeddingConfig)
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch: expected {kjt.values().numel()} ids but got {states.numel()} states"
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch: expected {kjt.values().numel()} ids but got {states.numel()} states"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -236,10 +362,16 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso
 
         Args:
             consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
         """
-        per_table_delta_rows = self.get_delta(consumer)
+        per_table_delta_rows = self.get_unique(consumer)
         return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
 
-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_unique(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
         """
         Return a dictionary of hit local IDs and parameter states / embeddings for
         each sparse feature. The Values are first keyed by submodule FQN.
@@ -258,12 +390,23 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
         # and index_start could be equal to index_end, in which case we should not compact again.
         if index_start < index_end:
             self.compact(index_start, index_end)
-        tracker_rows = self.store.get_delta(
+        tracker_rows = self.store.get_unique(
             from_idx=self.per_consumer_batch_idx[consumer]
         )
         self.per_consumer_batch_idx[consumer] = index_end
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                # and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
         return tracker_rows
 
     def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -280,7 +423,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             return self._fqn_to_feature_map
 
         table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
         for fqn, named_module in self._model.named_modules():
             split_fqn = fqn.split(".")
             # Skipping partial FQNs present in fqns_to_skip
@@ -306,13 +448,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                 # will incorrectly match fqn with all the table names that have the same prefix
                 if table_name in split_fqn:
                     embedding_fqn = self._clean_fqn_fn(fqn)
-                    if table_name in table_to_fqn:
+                    if table_name in self.table_to_fqn:
                         # Sanity check for validating that we don't have more then one table mapping to same fqn.
                         logger.warning(
-                            f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                            f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                         )
-                    table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                    self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
         flatten_names = [
             name for names in table_to_feature_names.values() for name in names
         ]
@@ -325,15 +467,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
 
         fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
         for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                 # This is likely unexpected, where we can't locate the FQN associated with this table.
                 logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names
@@ -380,13 +522,49 @@ def _clean_fqn_fn(self, fqn: str) -> str:
 
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
 
         # register auto compaction function at odist
         if self._auto_compact:
             # pyre-ignore[29]
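The MOMENTUM_DIFF / ROWWISE_ADAGRAD read path above reduces to a gather and a subtract: `get_unique` snapshots the running sum of squared gradients at first touch (EmbdUpdateMode.FIRST) and subtracts that snapshot from the latest sum fetched via `get_latest`. A toy illustration of the arithmetic, with invented values:

```python
import torch

# Row-wise Adagrad keeps one running sum of squared gradients per row.
start_sum = torch.tensor([1.0, 4.0, 9.0, 0.0])   # snapshot at interval start
latest_sum = torch.tensor([1.5, 6.0, 9.0, 2.0])  # what get_latest() would return

ids = torch.tensor([0, 1, 3])              # rows touched during the interval
delta = latest_sum[ids] - start_sum[ids]   # tensor([0.5000, 2.0000, 2.0000])

# Rows with larger accumulated squared gradients changed the most, so an
# approximate top-k of most-changed rows is delta.topk(k).
print(delta.topk(2).indices)
```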
diff --git a/torchrec/distributed/model_tracker/tests/test_delta_store.py b/torchrec/distributed/model_tracker/tests/test_delta_store.py
index 3401089fd..a8dc40d50 100644
--- a/torchrec/distributed/model_tracker/tests/test_delta_store.py
+++ b/torchrec/distributed/model_tracker/tests/test_delta_store.py
@@ -15,7 +15,7 @@
 from parameterized import parameterized
 from torchrec.distributed.model_tracker.delta_store import (
     _compute_unique_rows,
-    DeltaStore,
+    DeltaStoreTrec,
 )
 from torchrec.distributed.model_tracker.types import (
     DeltaRows,
@@ -24,7 +24,7 @@
 )
 
 
-class DeltaStoreTest(unittest.TestCase):
+class DeltaStoreTrecTest(unittest.TestCase):
     # pyre-fixme[2]: Parameter must be annotated.
     def __init__(self, methodName="runTest") -> None:
         super().__init__(methodName)
@@ -188,12 +188,12 @@ class AppendDeleteTestParams:
     def test_append_and_delete(
         self, _test_name: str, test_params: AppendDeleteTestParams
     ) -> None:
-        delta_store = DeltaStore()
+        delta_store = DeltaStoreTrec()
         for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items():
             for lookup in lookup_list:
                 delta_store.append(
                     batch_idx=lookup.batch_idx,
-                    table_fqn=table_fqn,
+                    fqn=table_fqn,
                     ids=lookup.ids,
                     states=lookup.states,
                 )
@@ -783,15 +783,15 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None:
         """
         Test the compact method of DeltaStore.
         """
-        # Create a DeltaStore with the specified embdUpdateMode
-        delta_store = DeltaStore(embdUpdateMode=test_params.embdUpdateMode)
+        # Create a DeltaStoreTrec with the specified embdUpdateMode
+        delta_store = DeltaStoreTrec(embdUpdateMode=test_params.embdUpdateMode)
 
         # Populate the DeltaStore with the test lookups
         for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items():
             for lookup in lookup_list:
                 delta_store.append(
                     batch_idx=lookup.batch_idx,
-                    table_fqn=table_fqn,
+                    fqn=table_fqn,
                     ids=lookup.ids,
                     states=lookup.states,
                 )
@@ -806,8 +806,8 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None:
         delta_store.compact(
             start_idx=test_params.start_idx, end_idx=test_params.end_idx
         )
-        # Verify the result using get_delta method
-        delta_result = delta_store.get_delta()
+        # Verify the result using get_unique method
+        delta_result = delta_store.get_unique()
 
         # compare all fqns in the result
         for table_fqn, delta_rows in test_params.expected_delta.items():
diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
index a92a9b286..b9bab879a 100644
--- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
+++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py
@@ -13,6 +13,9 @@
 import torch
 import torchrec
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from parameterized import parameterized
 from torch import nn
 
@@ -444,7 +447,7 @@ def test_fqn_to_feature_names(
             ),
         ),
         (
-            "get_delta",
+            "get_unique",
             ModelDeltaTrackerInputTestParams(
                 embedding_config_type=EmbeddingConfig,
                 embedding_tables=[
@@ -461,7 +464,7 @@
                 model_tracker_config=ModelTrackerConfig(),
             ),
             TrackerNotInitOutputTestParams(
-                dmp_tracker_atter="get_delta",
+                dmp_tracker_atter="get_unique",
             ),
         ),
     ]
@@ -1463,6 +1466,196 @@ def test_multiple_consumers(
             output_params=output_params,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "EC_and_single_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1"],
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+            (
+                "EBC_and_multiple_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingBagConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1", "f2"],
+                                pooling=PoolingType.SUM,
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_LAST,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+        ]
+    )
+    @skip_if_asan
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs")
+    def test_duplication_with_momentum(
+        self,
+        _test_name: str,
+        test_params: ModelDeltaTrackerInputTestParams,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=_test_duplication_with_momentum,
+            world_size=self.world_size,
+            test_params=test_params,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "EC_and_single_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1"],
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.MOMENTUM_DIFF,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+            (
+                "EBC_and_multiple_feature",
+                ModelDeltaTrackerInputTestParams(
+                    embedding_config_type=EmbeddingBagConfig,
+                    embedding_tables=[
+                        EmbeddingTableProps(
+                            embedding_table_config=EmbeddingBagConfig(
+                                name="sparse_table_1",
+                                num_embeddings=NUM_EMBEDDINGS,
+                                embedding_dim=EMBEDDING_DIM,
+                                feature_names=["f1", "f2"],
+                                pooling=PoolingType.SUM,
+                            ),
+                            sharding=ShardingType.ROW_WISE,
+                        ),
+                    ],
+                    model_tracker_config=ModelTrackerConfig(
+                        tracking_mode=TrackingMode.ROWWISE_ADAGRAD,
+                        delete_on_read=True,
+                    ),
+                    model_inputs=[
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]),
+                            offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]),
+                        ),
+                        ModelInput(
+                            keys=["f1", "f2"],
+                            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
+                            offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]),
+                        ),
+                    ],
+                ),
+            ),
+        ]
+    )
+    @skip_if_asan
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs")
+    def test_duplication_with_rowwise_adagrad(
+        self,
+        _test_name: str,
+        test_params: ModelDeltaTrackerInputTestParams,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=_test_duplication_with_rowwise_adagrad,
+            world_size=self.world_size,
+            test_params=test_params,
+        )
+
 
 def _test_fqn_to_feature_names(
     rank: int,
@@ -1650,7 +1843,7 @@ def _test_embedding_mode(
     tracked_out.sum().backward()
     baseline_out.sum().backward()
 
-    delta_rows = dt.get_delta()
+    delta_rows = dt.get_unique()
     table_fqns = dt.fqn_to_feature_names().keys()
     table_fqns_list = list(table_fqns)
 
@@ -1771,7 +1964,7 @@ def _test_multiple_get(
         unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
         tracked_out.sum().backward()
         baseline_out.sum().backward()
-        delta_rows = dt.get_delta()
+        delta_rows = dt.get_unique()
 
         # Verify that the current batch index is correct
         unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1)
@@ -1859,3 +2052,124 @@ def _test_multiple_consumer(
             and returned.allclose(expected_ids),
             f"{i=}, {table_fqn=}, mismatch {returned=} vs {expected_ids=}",
         )
+
+
+def _test_duplication_with_momentum(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    """
+    Test momentum tracking functionality in the model delta tracker.
+
+    Validates that the tracker correctly captures and stores momentum values from
+    optimizer states when using TrackingMode.MOMENTUM_LAST mode.
+    """
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+        )
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+            dt_model_opt.step()
+            baseline_opt.step()
+
+        delta_rows = dt.get_unique()
+        for table_fqn in table_fqns_list:
+            ids = delta_rows[table_fqn].ids
+            states = none_throws(delta_rows[table_fqn].states)
+
+            unittest.TestCase().assertTrue(states is not None)
+            unittest.TestCase().assertTrue(ids.numel() == states.numel())
+            unittest.TestCase().assertTrue(bool((states != 0).all().item()))
+
+
+def _test_duplication_with_rowwise_adagrad(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+            optimizer_type=OptimType.EXACT_ROWWISE_ADAGRAD,
+        )
+
+        # read momentum directly from the table
+        tbe: SplitTableBatchedEmbeddingBagsCodegen = (
+            (
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `ec`.
+                dt_model._dmp_wrapped_module.module.ec._lookups[0]
+                ._emb_modules[0]
+                .emb_module
+            )
+            if test_params.embedding_config_type == EmbeddingConfig
+            else (
+                dt_model._dmp_wrapped_module.module.ebc._lookups[0]  # pyre-ignore
+                ._emb_modules[0]
+                .emb_module
+            )
+        )
+        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+        start_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+
+            dt_model_opt.step()
+            baseline_opt.step()
+
+        end_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+        delta_rows = dt.get_unique()
+        table_fqn = table_fqns_list[0]
+
+        ids = delta_rows[table_fqn].ids
+        tracked_momentum = none_throws(delta_rows[table_fqn].states)
+        unittest.TestCase().assertTrue(tracked_momentum is not None)
+        unittest.TestCase().assertTrue(ids.numel() == tracked_momentum.numel())
+        unittest.TestCase().assertTrue(bool((tracked_momentum != 0).all().item()))
+
+        expected_momentum = end_momentums[ids] - start_momentums[ids]
+        unittest.TestCase().assertTrue(tracked_momentum.allclose(expected_momentum))
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index cec95af91..43a1b9223 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -41,13 +41,21 @@ class TrackingMode(Enum):
     Tracking mode for ``ModelDeltaTracker``.
 
     Enums:
-        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
-        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
-        enabling precise top-k result calculations. However, this option comes with increased memory usage.
+        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
+        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
+            enabling precise top-k result calculations. However, this option comes
+            with increased memory usage.
+        MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode
+            supports approximate top-k delta-row selection.
+        MOMENTUM_DIFF: Tracks both row IDs and their corresponding momentum difference values.
+        ROWWISE_ADAGRAD: Tracks both row IDs and their corresponding rowwise adagrad states.
     """
 
     ID_ONLY = "id_only"
     EMBEDDING = "embedding"
+    MOMENTUM_LAST = "momentum_last"
+    MOMENTUM_DIFF = "momentum_diff"
+    ROWWISE_ADAGRAD = "rowwise_adagrad"
 
 
 class EmbdUpdateMode(Enum):
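End to end, the feature is consumed through DistributedModelParallel as exercised by the tests above. A sketch of the intended flow, with model and optimizer construction elided and the tracker wiring assumed to follow the assertion messages in model_parallel.py ("Add ModelTrackerConfig at DistributedModelParallel init"):

```python
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode

config = ModelTrackerConfig(
    tracking_mode=TrackingMode.MOMENTUM_DIFF,  # ROWWISE_ADAGRAD behaves the same
    delete_on_read=True,
)
# dmp = DistributedModelParallel(module, ..., model_tracker_config=config)
#
# for batch in batches:
#     dmp(batch).sum().backward()
#     optimizer.step()
#
# per_table = dmp.get_unique()             # formerly dmp.get_delta()
# for fqn, rows in per_table.items():
#     ids, states = rows.ids, rows.states  # states = squared-grad delta per id
```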