diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 673fbccae..d8554edea 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -1587,7 +1587,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(features, embs, self)
 
         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 891dd0b02..9f3ce69c7 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -208,6 +208,10 @@ def __init__(
             )
 
         self.grouped_configs = grouped_configs
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -315,7 +319,13 @@ def forward(
                 self._feature_splits,
             )
             for emb_op, features in zip(self._emb_modules, features_by_group):
-                embeddings.append(emb_op(features).view(-1))
+                lookup = emb_op(features).view(-1)
+                embeddings.append(lookup)
+
+                # Model tracker optimizer state function; only called when the
+                # model tracker is configured to track optimizer state
+                if self.optim_state_tracker_fn is not None:
+                    self.optim_state_tracker_fn(features, lookup, emb_op)
 
         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)
 
@@ -420,6 +430,21 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
            emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
     """
@@ -519,6 +544,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function to track optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
+        ] = None
 
     def _create_embedding_kernel(
         self,
@@ -678,7 +707,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )
 
-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function; only called when the
+            # model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(features, lookup, emb_op)
 
         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -811,6 +845,21 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
            emb_module.purge()
 
+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
+    ) -> None:
+        """
+        Register a model tracker function to track optimizer state.
+
+        Args:
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+
 
 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 5434a3203..d0a5ef920 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -391,7 +391,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None
 
@@ -444,14 +444,16 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.
 
         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.
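+
+        Example::
+
+            # A minimal sketch of a record function matching the new signature;
+            # ``my_tracker`` is illustrative and stands in for any object that
+            # implements ``record_lookup``.
+            def record_fn(
+                kjt: KeyedJaggedTensor,
+                states: torch.Tensor,
+                module: Optional[nn.Module] = None,
+            ) -> None:
+                my_tracker.record_lookup(kjt, states, module)
+
+            sharded_module.register_post_lookup_tracker_fn(record_fn)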
""" if self.post_lookup_tracker_fn is not None: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 7c822051d..fd6117884 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1671,7 +1671,7 @@ def compute_and_output_dist( ): embs = lookup(features) if self.post_lookup_tracker_fn is not None: - self.post_lookup_tracker_fn(features, embs) + self.post_lookup_tracker_fn(features, embs, self) with maybe_annotate_embedding_event( EmbeddingEvent.OUTPUT_DIST, diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 7c9c45824..45f2288e9 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -29,8 +29,8 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size -from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker -from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig +from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec +from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders @@ -50,6 +50,7 @@ append_prefix, copy_to_device, filter_state_dict, + none_throws, sharded_model_copy, ) from torchrec.optim.fused import FusedOptimizerModule @@ -293,7 +294,7 @@ def __init__( if init_data_parallel: self.init_data_parallel() - self.model_delta_tracker: Optional[ModelDeltaTracker] = ( + self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = ( self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module) if model_tracker_config is not None else None @@ -369,9 +370,9 @@ def _init_dmp(self, module: nn.Module) -> nn.Module: def _init_delta_tracker( self, model_tracker_config: ModelTrackerConfig, module: nn.Module - ) -> ModelDeltaTracker: + ) -> ModelDeltaTrackerTrec: # Init delta tracker if config is provided - return ModelDeltaTracker( + return ModelDeltaTrackerTrec( model=module, consumers=model_tracker_config.consumers, delete_on_read=model_tracker_config.delete_on_read, @@ -456,7 +457,20 @@ def init_parameters(module: nn.Module) -> None: module.apply(init_parameters) - def get_model_tracker(self) -> ModelDeltaTracker: + def init_torchrec_delta_tracker( + self, model_tracker_config: ModelTrackerConfig + ) -> ModelDeltaTrackerTrec: + """ + Initializes the model delta tracker if it doesn't exists. + """ + if self.model_delta_tracker is None: + self.model_delta_tracker = self._init_delta_tracker( + model_tracker_config, self._dmp_wrapped_module + ) + + return none_throws(self.model_delta_tracker) + + def get_model_tracker(self) -> ModelDeltaTrackerTrec: """ Returns the model tracker if it exists. """ @@ -466,14 +480,14 @@ def get_model_tracker(self) -> ModelDeltaTracker: ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." return self.model_delta_tracker - def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: + def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]: """ Returns the delta rows for the given consumer. """ assert ( self.model_delta_tracker is not None ), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init." 
-        return self.model_delta_tracker.get_delta(consumer)
+        return self.model_delta_tracker.get_unique(consumer)
 
     def sparse_grad_parameter_names(
         self, destination: Optional[List[str]] = None, prefix: str = ""
diff --git a/torchrec/distributed/model_tracker/__init__.py b/torchrec/distributed/model_tracker/__init__.py
index ab7a47a33..2895e218a 100644
--- a/torchrec/distributed/model_tracker/__init__.py
+++ b/torchrec/distributed/model_tracker/__init__.py
@@ -27,9 +27,9 @@
     SUPPORTED_MODULES,  # noqa
 )
 from torchrec.distributed.model_tracker.types import (
-    DeltaRows,  # noqa
-    EmbdUpdateMode,  # noqa
     IndexedLookup,  # noqa
     ModelTrackerConfig,  # noqa
     TrackingMode,  # noqa
+    UniqueRows,  # noqa
+    UpdateMode,  # noqa
 )
diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py
index 34ea8c88f..bd2ee1b27 100644
--- a/torchrec/distributed/model_tracker/delta_store.py
+++ b/torchrec/distributed/model_tracker/delta_store.py
@@ -7,14 +7,15 @@
 
 # pyre-strict
 
+from abc import ABC, abstractmethod
 from bisect import bisect_left
 from typing import Dict, List, Optional
 
 import torch
 from torchrec.distributed.model_tracker.types import (
-    DeltaRows,
-    EmbdUpdateMode,
     IndexedLookup,
+    UniqueRows,
+    UpdateMode,
 )
 from torchrec.distributed.utils import none_throws
 
@@ -22,24 +23,24 @@
 def _compute_unique_rows(
     ids: List[torch.Tensor],
     states: Optional[List[torch.Tensor]],
-    mode: EmbdUpdateMode,
-) -> DeltaRows:
+    mode: UpdateMode,
+) -> UniqueRows:
     r"""
     To calculate unique ids and embeddings
     """
-    if mode == EmbdUpdateMode.NONE:
-        assert states is None, f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
+    if mode == UpdateMode.NONE:
+        assert states is None, f"{mode=} == UpdateMode.NONE but received embeddings"
         unique_ids = torch.cat(ids).unique(return_inverse=False)
-        return DeltaRows(ids=unique_ids, states=None)
+        return UniqueRows(ids=unique_ids, states=None)
     else:
         assert (
             states is not None
-        ), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"
+        ), f"{mode=} != UpdateMode.NONE but received no embeddings"
         cat_ids = torch.cat(ids)
         cat_states = torch.cat(states)
 
-        if mode == EmbdUpdateMode.LAST:
+        if mode == UpdateMode.LAST:
             cat_ids = cat_ids.flip(dims=[0])
             cat_states = cat_states.flip(dims=[0])
 
@@ -64,37 +65,109 @@ def _compute_unique_rows(
 
     # Use first occurrence indices to select corresponding embedding row.
     unique_states = cat_states[first_occurrence]
-    return DeltaRows(ids=unique_ids, states=unique_states)
+    return UniqueRows(ids=unique_ids, states=unique_states)
 
 
-class DeltaStore:
+class DeltaStore(ABC):
     """
-    DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
-    various batches during training, designed to be used with TorchRecs ModelDeltaTracker.
-    It maintains a CUDA in-memory representation of requested ids and embeddings/states,
+    DeltaStore is an abstract base class that defines the interface for storing and managing
+    local delta (row) updates for embeddings/states across various batches during training.
+
+    Implementations should maintain a representation of requested ids and embeddings/states,
     providing a way to compact and get delta updates for each embedding table.
 
     The class supports different embedding update modes (NONE, FIRST, LAST) to determine
     how to handle duplicate ids when compacting or retrieving embeddings.
+ """ + + @abstractmethod + def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None: + pass + + @abstractmethod + def append( + self, + batch_idx: int, + fqn: str, + ids: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + """ + Append a batch of ids and states to the store for a specific table. + + Args: + batch_idx: The batch index + table_fqn: The fully qualified name of the table + ids: The tensor of ids to append + states: Optional tensor of states to append + """ + pass + + @abstractmethod + def delete(self, up_to_idx: Optional[int] = None) -> None: + """ + Delete all idx from the store up to `up_to_idx` + + Args: + up_to_idx: Optional index up to which to delete lookups + """ + pass + @abstractmethod + def compact(self, start_idx: int, end_idx: int) -> None: + """ + Compact (ids, embeddings) in batch index range from start_idx to end_idx. + + Args: + start_idx: The starting batch index + end_idx: The ending batch index + """ + pass + + @abstractmethod + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: + """ + Return all unique/delta ids per table from the Delta Store. + + Args: + from_idx: The batch index from which to get deltas + + Returns: + A dictionary mapping table FQNs to their delta rows + """ + pass + + +class DeltaStoreTrec(DeltaStore): + """ + DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages + local delta (row) updates for embeddings/states across various batches during training, + designed to be used with TorchRecs ModelDeltaTracker. + + It maintains a CUDA in-memory representation of requested ids and embeddings/states, + providing a way to compact and get delta updates for each embedding table. + + The class supports different embedding update modes (NONE, FIRST, LAST) to determine + how to handle duplicate ids when compacting or retrieving embeddings. """ - def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: - self.embdUpdateMode = embdUpdateMode + def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None: + super().__init__(updateMode) + self.updateMode = updateMode self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} def append( self, batch_idx: int, - table_fqn: str, + fqn: str, ids: torch.Tensor, states: Optional[torch.Tensor], ) -> None: - table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, []) + table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) table_fqn_lookup.append( IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) ) - self.per_fqn_lookups[table_fqn] = table_fqn_lookup + self.per_fqn_lookups[fqn] = table_fqn_lookup def delete(self, up_to_idx: Optional[int] = None) -> None: """ @@ -132,11 +205,11 @@ def compact(self, start_idx: int, end_idx: int) -> None: ids = [lookup.ids for lookup in lookups_to_compact] states = ( [none_throws(lookup.states) for lookup in lookups_to_compact] - if self.embdUpdateMode != EmbdUpdateMode.NONE + if self.updateMode != UpdateMode.NONE else None ) delta_rows = _compute_unique_rows( - ids=ids, states=states, mode=self.embdUpdateMode + ids=ids, states=states, mode=self.updateMode ) new_per_fqn_lookups[table_fqn] = ( lookups[:index_l] @@ -151,12 +224,12 @@ def compact(self, start_idx: int, end_idx: int) -> None: ) self.per_fqn_lookups = new_per_fqn_lookups - def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: r""" Return all unique/delta ids per table from the Delta Store. 
""" - delta_per_table_fqn: Dict[str, DeltaRows] = {} + delta_per_table_fqn: Dict[str, UniqueRows] = {} for table_fqn, lookups in self.per_fqn_lookups.items(): compact_ids = [ lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx @@ -167,11 +240,11 @@ def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: for lookup in lookups if lookup.batch_idx >= from_idx ] - if self.embdUpdateMode != EmbdUpdateMode.NONE + if self.updateMode != UpdateMode.NONE else None ) delta_per_table_fqn[table_fqn] = _compute_unique_rows( - ids=compact_ids, states=compact_states, mode=self.embdUpdateMode + ids=compact_ids, states=compact_states, mode=self.updateMode ) return delta_per_table_fqn diff --git a/torchrec/distributed/model_tracker/model_delta_tracker.py b/torchrec/distributed/model_tracker/model_delta_tracker.py index 905bf7648..c3345e7c6 100644 --- a/torchrec/distributed/model_tracker/model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/model_delta_tracker.py @@ -7,39 +7,134 @@ # pyre-strict import logging as logger +from abc import ABC, abstractmethod from collections import Counter, OrderedDict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Tuple import torch +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn +from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding + from torchrec.distributed.embedding import ShardedEmbeddingCollection +from torchrec.distributed.embedding_lookup import ( + BatchedFusedEmbeddingBag, + GroupedEmbeddingsLookup, + GroupedPooledEmbeddingsLookup, +) from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection -from torchrec.distributed.model_tracker.delta_store import DeltaStore +from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec from torchrec.distributed.model_tracker.types import ( - DeltaRows, - EmbdUpdateMode, TrackingMode, + UniqueRows, + UpdateMode, ) +from torchrec.distributed.utils import none_throws + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = { +UPDATE_MODE_MAP: Dict[TrackingMode, UpdateMode] = { # Only IDs are tracked, no additional state is stored. - TrackingMode.ID_ONLY: EmbdUpdateMode.NONE, - # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that - # the earliest embedding values are stored since the last checkpoint or snapshot. - # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk(). - TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST, + TrackingMode.ID_ONLY: UpdateMode.NONE, + # TrackingMode.EMBEDDING utilizes UpdateMode.FIRST to ensure that + # the earliest embedding values are stored since the last checkpoint + # or snapshot. This mode is used for computing topk delta rows, which + # is currently achieved by running (new_emb - old_emb).norm().topk(). + TrackingMode.EMBEDDING: UpdateMode.FIRST, + # TrackingMode.MOMENTUM utilizes UpdateMode.LAST to ensure that + # the most recent momentum values—capturing the accumulated gradient + # direction and magnitude—are stored since the last batch. + # This mode supports approximate top-k delta-row selection, can be + # obtained by running momentum.norm().topk(). 
+    TrackingMode.MOMENTUM_LAST: UpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the square of the gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum on all used rows and then do a lookup when ``get_unique`` is called to query
+    # the latest sum. Then we can compute the delta of the two values and return them
+    # together with the row ids.
+    TrackingMode.MOMENTUM_DIFF: UpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Added for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: UpdateMode.FIRST,
 }
 
 # Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
 SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
 
 
-class ModelDeltaTracker:
+class ModelDeltaTracker(ABC):
     r"""
-    ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
+    Abstract base class for ModelDeltaTracker that provides a way to track and retrieve unique IDs for supported modules,
+    along with optional support for tracking corresponding embeddings or states. This is useful for identifying and
+    retrieving the latest delta or unique rows for a given model, which can help compute topk or to stream updated
+    embeddings from predictors to trainers during online training.
+
+    """
+
+    DEFAULT_CONSUMER: str = "default"
+
+    @abstractmethod
+    def record_lookup(
+        self,
+        kjt: KeyedJaggedTensor,
+        states: torch.Tensor,
+        emb_module: Optional[nn.Module] = None,
+    ) -> None:
+        """
+        Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
+
+        Args:
+            kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
+            states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
+            emb_module (nn.Module, optional): The embedding module in which the lookup was performed.
+        """
+        pass
+
+    @abstractmethod
+    def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+        """
+        Return a dictionary of hit local IDs for each sparse feature.
+
+        Args:
+            consumer (str, optional): The consumer to retrieve unique IDs for.
+        """
+        pass
+
+    @abstractmethod
+    def get_unique(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, UniqueRows]:
+        """
+        Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature.
+
+        Args:
+            consumer (str, optional): The consumer to retrieve delta values for.
+        """
+        pass
+
+    @abstractmethod
+    def clear(self, consumer: Optional[str] = None) -> None:
+        """
+        Clear tracked IDs for a given consumer.
+
+        Args:
+            consumer (str, optional): The consumer to clear IDs/States for.
+        """
+        pass
+
+
+class ModelDeltaTrackerTrec(ModelDeltaTracker):
+    r"""
+    ModelDeltaTrackerTrec provides a way to track and retrieve unique IDs for supported modules, along with optional support
     for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
     unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during
     online training. Unique IDs or states can be retrieved by calling the get_unique() method.
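+
+    Example::
+
+        # A minimal sketch; in practice DistributedModelParallel constructs the
+        # tracker from ModelTrackerConfig (names below are illustrative).
+        tracker = ModelDeltaTrackerTrec(model=sharded_model, consumers=["default"])
+        out = sharded_model(batch)   # lookups are recorded via the registered hooks
+        out.sum().backward()
+        rows = tracker.get_unique()  # Dict[table_fqn, UniqueRows]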
@@ -57,8 +152,6 @@ class ModelDeltaTracker:
 
     """
 
-    DEFAULT_CONSUMER: str = "default"
-
     def __init__(
         self,
         model: nn.Module,
@@ -87,13 +180,14 @@ def __init__(
 
         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
         # Validate that the mode is supported for the given module and initialize tracker functions
         self._validate_and_init_tracker_fns()
 
-        self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
+        self.store: DeltaStoreTrec = DeltaStoreTrec(UPDATE_MODE_MAP[self._mode])
 
         # Mapping feature name to corresponding FQNs. This is used for retrieving
         # the FQN associated with a given feature name in record_lookup().
@@ -141,7 +235,12 @@ def trigger_compaction(self) -> None:
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx
 
-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self,
+        kjt: KeyedJaggedTensor,
+        states: torch.Tensor,
+        emb_module: Optional[nn.Module] = None,
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
 
@@ -152,6 +251,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).
 
         Args:
+            emb_module (nn.Module, optional): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
@@ -162,7 +262,14 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(none_throws(emb_module), kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(none_throws(emb_module), kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
 
@@ -183,7 +290,7 @@ def record_ids(self, kjt: KeyedJaggedTensor) -> None:
         for table_fqn, ids_list in per_table_ids.items():
             self.store.append(
                 batch_idx=self.curr_batch_idx,
-                table_fqn=table_fqn,
+                fqn=table_fqn,
                 ids=torch.cat(ids_list),
                 states=None,
             )
@@ -223,12 +330,99 @@ def record_embeddings(
         for table_fqn, ids_list in per_table_ids.items():
             self.store.append(
                 batch_idx=self.curr_batch_idx,
-                table_fqn=table_fqn,
+                fqn=table_fqn,
                 ids=torch.cat(ids_list),
                 states=torch.cat(per_table_emb[table_fqn]),
             )
 
-    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from the last iteration; use momentum from
+        # the current iteration for correctness.
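+        # Note: the code below assumes momentum1_dev stores one momentum value per
+        # weight element, flattened; viewing it as (num_rows, embedding_dim) and
+        # indexing with kjt.values() selects the rows touched in this batch, whose
+        # per-row L2 norm serves as a compact per-id score.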
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group, information can be extracted from
+        # module._config (i.e., GroupedEmbeddingConfig)
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch: expected {kjt.values().numel()} but got {states.numel()}"
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch: expected {kjt.values().numel()} but got {states.numel()}"
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                    tbe = embs_module._emb_module
+
+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
+    def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
         first keyed by submodule FQN.
 
@@ -236,10 +430,16 @@
         Args:
             consumer (str, optional): The consumer to retrieve unique IDs for.
                 If not specified, "default" is used as the default consumer.
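+
+        Example::
+
+            # A minimal sketch (consumer names and FQNs are illustrative):
+            ids_per_fqn = tracker.get_unique_ids(consumer="default")
+            for fqn, ids in ids_per_fqn.items():
+                print(fqn, ids.numel())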
""" - per_table_delta_rows = self.get_delta(consumer) + per_table_delta_rows = self.get_unique(consumer) return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()} - def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: + def get_unique( + self, + consumer: Optional[str] = None, + top_percentage: Optional[float] = 1.0, + per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None, + sorted_by_indices: Optional[bool] = True, + ) -> Dict[str, UniqueRows]: """ Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The Values are first keyed by submodule FQN. @@ -258,12 +458,22 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: # and index_start could be equal to index_end, in which case we should not compact again. if index_start < index_end: self.compact(index_start, index_end) - tracker_rows = self.store.get_delta( + tracker_rows = self.store.get_unique( from_idx=self.per_consumer_batch_idx[consumer] ) self.per_consumer_batch_idx[consumer] = index_end if self._delete_on_read: self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values())) + + if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD): + square_sum_map = self.get_latest() + for fqn, rows in tracker_rows.items(): + assert ( + fqn in square_sum_map + ), f"{fqn} not found in {square_sum_map.keys()}" + # pyre-fixme[58]: `-` is not supported for operand types `Tensor` and `Optional[Tensor]`. + rows.states = square_sum_map[fqn][rows.ids] - rows.states + return tracker_rows def get_tracked_modules(self) -> Dict[str, nn.Module]: @@ -280,7 +490,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: return self._fqn_to_feature_map table_to_feature_names: Dict[str, List[str]] = OrderedDict() - table_to_fqn: Dict[str, str] = OrderedDict() for fqn, named_module in self._model.named_modules(): split_fqn = fqn.split(".") # Skipping partial FQNs present in fqns_to_skip @@ -306,13 +515,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: # will incorrectly match fqn with all the table names that have the same prefix if table_name in split_fqn: embedding_fqn = self._clean_fqn_fn(fqn) - if table_name in table_to_fqn: + if table_name in self.table_to_fqn: # Sanity check for validating that we don't have more then one table mapping to same fqn. logger.warning( - f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" + f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}" ) - table_to_fqn[table_name] = embedding_fqn - logger.info(f"Table to fqn: {table_to_fqn}") + self.table_to_fqn[table_name] = embedding_fqn + logger.info(f"Table to fqn: {self.table_to_fqn}") flatten_names = [ name for names in table_to_feature_names.values() for name in names ] @@ -325,15 +534,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]: fqn_to_feature_names: Dict[str, List[str]] = OrderedDict() for table_name in table_to_feature_names: - if table_name not in table_to_fqn: + if table_name not in self.table_to_fqn: # This is likely unexpected, where we can't locate the FQN associated with this table. 
                logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names
 
@@ -380,13 +589,49 @@ def _clean_fqn_fn(self, fqn: str) -> str:
 
     def _validate_and_init_tracker_fns(self) -> None:
         "Validate that the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
 
         # register auto compaction function at odist
         if self._auto_compact:
             # pyre-ignore[29]
diff --git a/torchrec/distributed/model_tracker/tests/test_delta_store.py b/torchrec/distributed/model_tracker/tests/test_delta_store.py
index 3401089fd..807aac76d 100644
--- a/torchrec/distributed/model_tracker/tests/test_delta_store.py
+++ b/torchrec/distributed/model_tracker/tests/test_delta_store.py
@@ -15,16 +15,16 @@
 from parameterized import parameterized
 from torchrec.distributed.model_tracker.delta_store import (
     _compute_unique_rows,
-    DeltaStore,
+    DeltaStoreTrec,
 )
 from torchrec.distributed.model_tracker.types import (
-    DeltaRows,
-    EmbdUpdateMode,
     IndexedLookup,
+    UniqueRows,
+    UpdateMode,
 )
 
 
-class DeltaStoreTest(unittest.TestCase):
+class DeltaStoreTrecTest(unittest.TestCase):
     # pyre-fixme[2]: Parameter must be annotated.
def __init__(self, methodName="runTest") -> None: super().__init__(methodName) @@ -188,12 +188,12 @@ class AppendDeleteTestParams: def test_append_and_delete( self, _test_name: str, test_params: AppendDeleteTestParams ) -> None: - delta_store = DeltaStore() + delta_store = DeltaStoreTrec() for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): for lookup in lookup_list: delta_store.append( batch_idx=lookup.batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=lookup.ids, states=lookup.states, ) @@ -214,14 +214,14 @@ class ComputeTestParams: # input parameters ids: List[torch.Tensor] embeddings: Optional[List[torch.Tensor]] - embdUpdateMode: EmbdUpdateMode + updateMode: UpdateMode # expected output parameters - expected_output: DeltaRows + expected_output: UniqueRows expect_assert: bool @parameterized.expand( [ - # test cases for EmbdUpdateMode.NONE + # test cases for UpdateMode.NONE ( "unique_ids", ComputeTestParams( @@ -230,8 +230,8 @@ class ComputeTestParams: torch.tensor([6, 7, 8, 9, 10]), ], embeddings=None, - embdUpdateMode=EmbdUpdateMode.NONE, - expected_output=DeltaRows( + updateMode=UpdateMode.NONE, + expected_output=UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), states=None, ), @@ -246,15 +246,15 @@ class ComputeTestParams: torch.tensor([2, 10, 8, 4, 9, 7]), ], embeddings=None, - embdUpdateMode=EmbdUpdateMode.NONE, - expected_output=DeltaRows( + updateMode=UpdateMode.NONE, + expected_output=UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), states=None, ), expect_assert=False, ), ), - # test case for EmbdUpdateMode.NONE with embeddings (should assert) + # test case for UpdateMode.NONE with embeddings (should assert) ( "none_mode_with_embeddings", ComputeTestParams( @@ -266,15 +266,15 @@ class ComputeTestParams: torch.tensor([[1.0], [2.0], [3.0]]), torch.tensor([[4.0], [5.0], [6.0]]), ], - embdUpdateMode=EmbdUpdateMode.NONE, - expected_output=DeltaRows( + updateMode=UpdateMode.NONE, + expected_output=UniqueRows( ids=torch.tensor([]), states=None, ), expect_assert=True, ), ), - # test cases for EmbdUpdateMode.FIRST + # test cases for UpdateMode.FIRST ( "first_mode_without_embeddings", ComputeTestParams( @@ -283,8 +283,8 @@ class ComputeTestParams: torch.tensor([4, 5, 6]), ], embeddings=None, - embdUpdateMode=EmbdUpdateMode.FIRST, - expected_output=DeltaRows( + updateMode=UpdateMode.FIRST, + expected_output=UniqueRows( ids=torch.tensor([]), states=None, ), @@ -302,8 +302,8 @@ class ComputeTestParams: torch.tensor([[1.0], [2.0], [3.0]]), torch.tensor([[4.0], [5.0], [6.0]]), ], - embdUpdateMode=EmbdUpdateMode.FIRST, - expected_output=DeltaRows( + updateMode=UpdateMode.FIRST, + expected_output=UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6]), states=torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]), ), @@ -321,8 +321,8 @@ class ComputeTestParams: torch.tensor([[40.0], [10.0], [30.0], [60.0], [50.0], [20.0]]), torch.tensor([[25.0], [100.0], [80.0], [45.0], [90.0], [70.0]]), ], - embdUpdateMode=EmbdUpdateMode.FIRST, - expected_output=DeltaRows( + updateMode=UpdateMode.FIRST, + expected_output=UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), # First occurrence of each ID is kept states=torch.tensor( @@ -343,7 +343,7 @@ class ComputeTestParams: expect_assert=False, ), ), - # test cases for EmbdUpdateMode.LAST + # test cases for UpdateMode.LAST ( "last_mode_without_embeddings", ComputeTestParams( @@ -352,8 +352,8 @@ class ComputeTestParams: torch.tensor([4, 5, 6]), ], embeddings=None, - embdUpdateMode=EmbdUpdateMode.LAST, - 
-                    expected_output=DeltaRows(
+                    updateMode=UpdateMode.LAST,
+                    expected_output=UniqueRows(
                        ids=torch.tensor([]),
                        states=None,
                    ),
@@ -371,8 +371,8 @@ class ComputeTestParams:
                    torch.tensor([[1.0], [2.0], [3.0]]),
                    torch.tensor([[4.0], [5.0], [6.0]]),
                ],
-                    embdUpdateMode=EmbdUpdateMode.LAST,
-                    expected_output=DeltaRows(
+                    updateMode=UpdateMode.LAST,
+                    expected_output=UniqueRows(
                        ids=torch.tensor([1, 2, 3, 4, 5, 6]),
                        states=torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]),
                    ),
@@ -390,8 +390,8 @@ class ComputeTestParams:
                    torch.tensor([[40.0], [10.0], [30.0], [60.0], [50.0], [20.0]]),
                    torch.tensor([[25.0], [100.0], [80.0], [45.0], [90.0], [70.0]]),
                ],
-                    embdUpdateMode=EmbdUpdateMode.LAST,
-                    expected_output=DeltaRows(
+                    updateMode=UpdateMode.LAST,
+                    expected_output=UniqueRows(
                        ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
                        # Last occurrence of each ID is kept
                        states=torch.tensor(
                            [
@@ -421,12 +421,12 @@ def test_compute_unique_rows(
            # If we expect an assertion error, check that it's raised
            with self.assertRaises(AssertionError):
                _compute_unique_rows(
-                    test_params.ids, test_params.embeddings, test_params.embdUpdateMode
+                    test_params.ids, test_params.embeddings, test_params.updateMode
                )
        else:
            # Otherwise, proceed with the normal test
            result = _compute_unique_rows(
-                test_params.ids, test_params.embeddings, test_params.embdUpdateMode
+                test_params.ids, test_params.embeddings, test_params.updateMode
            )

            self.assertTrue(torch.equal(result.ids, test_params.expected_output.ids))
@@ -444,21 +444,21 @@ def test_compute_unique_rows(
@dataclass
class CompactTestParams:
    # input parameters
-    embdUpdateMode: EmbdUpdateMode
+    updateMode: UpdateMode
    table_fqn_to_lookups: Dict[str, List[IndexedLookup]]
    start_idx: int
    end_idx: int
    # expected output parameters
-    expected_delta: Dict[str, DeltaRows]
+    expected_delta: Dict[str, UniqueRows]
    expect_assert: bool = False

    @parameterized.expand(
        [
-            # Test case for compaction with EmbdUpdateMode.NONE
+            # Test case for compaction with UpdateMode.NONE
            (
                "empty_lookups",
                CompactTestParams(
-                    embdUpdateMode=EmbdUpdateMode.NONE,
+                    updateMode=UpdateMode.NONE,
                    table_fqn_to_lookups={},
                    start_idx=1,
                    end_idx=5,
                ),
            ),
            (
                "single_lookup_no_compaction",
                CompactTestParams(
-                    embdUpdateMode=EmbdUpdateMode.NONE,
+                    updateMode=UpdateMode.NONE,
                    table_fqn_to_lookups={
                        "table_fqn_1": [
                            IndexedLookup(
@@ -481,7 +481,7 @@ class CompactTestParams:
                    start_idx=1,
                    end_idx=5,
                    expected_delta={
-                        "table_fqn_1": DeltaRows(
+                        "table_fqn_1": UniqueRows(
                            ids=torch.tensor([1, 2, 3]),
                            states=None,
                        ),
                    },
                ),
            ),
            (
                "multi_lookup_all_unique",
                CompactTestParams(
-                    embdUpdateMode=EmbdUpdateMode.NONE,
+                    updateMode=UpdateMode.NONE,
                    table_fqn_to_lookups={
                        "table_fqn_1": [
                            IndexedLookup(
@@ -514,7 +514,7 @@ class CompactTestParams:
                    start_idx=1,
                    end_idx=3,
                    expected_delta={
-                        "table_fqn_1": DeltaRows(
+                        "table_fqn_1": UniqueRows(
                            ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]),
                            states=None,
                        ),
                    },
                ),
            ),
            (
                "multi_lookup_with_duplicates",
                CompactTestParams(
-                    embdUpdateMode=EmbdUpdateMode.NONE,
+                    updateMode=UpdateMode.NONE,
                    table_fqn_to_lookups={
                        "table_fqn_1": [
                            IndexedLookup(
@@ -552,18 +552,18 @@ class CompactTestParams:
                    start_idx=1,
                    end_idx=4,
                    expected_delta={
-                        "table_fqn_1": DeltaRows(
+                        "table_fqn_1": UniqueRows(
                            ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]),
                            states=None,
                        ),
                    },
                ),
            ),
-            # Test case for compaction with EmbdUpdateMode.FIRST
+            # Test case for compaction with UpdateMode.FIRST
            (
"multi_lookup_with_duplicates_first_mode", CompactTestParams( - embdUpdateMode=EmbdUpdateMode.FIRST, + updateMode=UpdateMode.FIRST, table_fqn_to_lookups={ "table_fqn_1": [ IndexedLookup( @@ -591,7 +591,7 @@ class CompactTestParams: start_idx=1, end_idx=4, expected_delta={ - "table_fqn_1": DeltaRows( + "table_fqn_1": UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), states=torch.tensor( [ @@ -613,7 +613,7 @@ class CompactTestParams: ( "multiple_tables_first_mode", CompactTestParams( - embdUpdateMode=EmbdUpdateMode.FIRST, + updateMode=UpdateMode.FIRST, table_fqn_to_lookups={ "table_fqn_1": [ IndexedLookup( @@ -643,13 +643,13 @@ class CompactTestParams: start_idx=1, end_idx=3, expected_delta={ - "table_fqn_1": DeltaRows( + "table_fqn_1": UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5]), states=torch.tensor( [[10.0], [20.0], [30.0], [40.0], [50.0]] ), ), - "table_fqn_2": DeltaRows( + "table_fqn_2": UniqueRows( ids=torch.tensor([10, 20, 30, 40, 50]), states=torch.tensor( [[100.0], [200.0], [300.0], [400.0], [500.0]] @@ -658,11 +658,11 @@ class CompactTestParams: }, ), ), - # Test case for compaction with EmbdUpdateMode.LAST + # Test case for compaction with UpdateMode.LAST ( "multi_lookup_with_duplicates_last_mode", CompactTestParams( - embdUpdateMode=EmbdUpdateMode.LAST, + updateMode=UpdateMode.LAST, table_fqn_to_lookups={ "table_fqn_1": [ IndexedLookup( @@ -690,7 +690,7 @@ class CompactTestParams: start_idx=1, end_idx=4, expected_delta={ - "table_fqn_1": DeltaRows( + "table_fqn_1": UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), states=torch.tensor( [ @@ -712,7 +712,7 @@ class CompactTestParams: ( "multiple_tables_last_mode", CompactTestParams( - embdUpdateMode=EmbdUpdateMode.LAST, + updateMode=UpdateMode.LAST, table_fqn_to_lookups={ "table_fqn_1": [ IndexedLookup( @@ -742,13 +742,13 @@ class CompactTestParams: start_idx=1, end_idx=3, expected_delta={ - "table_fqn_1": DeltaRows( + "table_fqn_1": UniqueRows( ids=torch.tensor([1, 2, 3, 4, 5]), states=torch.tensor( [[10.0], [20.0], [35.0], [40.0], [50.0]] ), ), - "table_fqn_2": DeltaRows( + "table_fqn_2": UniqueRows( ids=torch.tensor([10, 20, 30, 40, 50]), states=torch.tensor( [[100.0], [200.0], [350.0], [400.0], [500.0]] @@ -761,7 +761,7 @@ class CompactTestParams: ( "invalid_indices", CompactTestParams( - embdUpdateMode=EmbdUpdateMode.NONE, + updateMode=UpdateMode.NONE, table_fqn_to_lookups={ "table_fqn_1": [ IndexedLookup( @@ -783,15 +783,15 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None: """ Test the compact method of DeltaStore. 
""" - # Create a DeltaStore with the specified embdUpdateMode - delta_store = DeltaStore(embdUpdateMode=test_params.embdUpdateMode) + # Create a DeltaStoreTrec with the specified updateMode + delta_store = DeltaStoreTrec(updateMode=test_params.updateMode) # Populate the DeltaStore with the test lookups for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): for lookup in lookup_list: delta_store.append( batch_idx=lookup.batch_idx, - table_fqn=table_fqn, + fqn=table_fqn, ids=lookup.ids, states=lookup.states, ) @@ -806,8 +806,8 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None: delta_store.compact( start_idx=test_params.start_idx, end_idx=test_params.end_idx ) - # Verify the result using get_delta method - delta_result = delta_store.get_delta() + # Verify the result using get_unique method + delta_result = delta_store.get_unique() # compare all fqns in the result for table_fqn, delta_rows in test_params.expected_delta.items(): diff --git a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py index a92a9b286..a46268a3b 100644 --- a/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py +++ b/torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py @@ -13,6 +13,9 @@ import torch import torchrec from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_table_batched_embeddings_ops import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from parameterized import parameterized from torch import nn @@ -444,7 +447,7 @@ def test_fqn_to_feature_names( ), ), ( - "get_delta", + "get_unique", ModelDeltaTrackerInputTestParams( embedding_config_type=EmbeddingConfig, embedding_tables=[ @@ -461,7 +464,7 @@ def test_fqn_to_feature_names( model_tracker_config=ModelTrackerConfig(), ), TrackerNotInitOutputTestParams( - dmp_tracker_atter="get_delta", + dmp_tracker_atter="get_unique", ), ), ] @@ -1463,6 +1466,196 @@ def test_multiple_consumers( output_params=output_params, ) + @parameterized.expand( + [ + ( + "EC_and_single_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1"], + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_LAST, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ( + "EBC_and_multiple_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingBagConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1", "f2"], + pooling=PoolingType.SUM, + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_LAST, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + 
keys=["f1", "f2"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ] + ) + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs") + def test_duplication_with_momentum( + self, + _test_name: str, + test_params: ModelDeltaTrackerInputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_duplication_with_momentum, + world_size=self.world_size, + test_params=test_params, + ) + + @parameterized.expand( + [ + ( + "EC_and_single_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1"], + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.MOMENTUM_DIFF, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ( + "EBC_and_multiple_feature", + ModelDeltaTrackerInputTestParams( + embedding_config_type=EmbeddingBagConfig, + embedding_tables=[ + EmbeddingTableProps( + embedding_table_config=EmbeddingBagConfig( + name="sparse_table_1", + num_embeddings=NUM_EMBEDDINGS, + embedding_dim=EMBEDDING_DIM, + feature_names=["f1", "f2"], + pooling=PoolingType.SUM, + ), + sharding=ShardingType.ROW_WISE, + ), + ], + model_tracker_config=ModelTrackerConfig( + tracking_mode=TrackingMode.ROWWISE_ADAGRAD, + delete_on_read=True, + ), + model_inputs=[ + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14]), + offsets=torch.tensor([0, 2, 2, 4, 6, 7, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([8, 10, 12, 14, 0, 2, 4, 6]), + offsets=torch.tensor([0, 2, 2, 4, 6, 6, 8]), + ), + ModelInput( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 0, 0, 4, 4, 4, 8]), + ), + ], + ), + ), + ] + ) + @skip_if_asan + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2+ GPUs") + def test_duplication_with_rowwise_adagrad( + self, + _test_name: str, + test_params: ModelDeltaTrackerInputTestParams, + ) -> None: + self._run_multi_process_test( + callable=_test_duplication_with_rowwise_adagrad, + world_size=self.world_size, + test_params=test_params, + ) + def _test_fqn_to_feature_names( rank: int, @@ -1546,7 +1739,7 @@ def _test_id_mode( tracked_out.sum().backward() baseline_out.sum().backward() - delta_ids = dt.get_delta_ids() + delta_ids = dt.get_unique_ids() table_fqns = dt.fqn_to_feature_names().keys() @@ -1650,7 +1843,7 @@ def _test_embedding_mode( tracked_out.sum().backward() baseline_out.sum().backward() - 
-    delta_rows = dt.get_delta()
+    delta_rows = dt.get_unique()

    table_fqns = dt.fqn_to_feature_names().keys()
    table_fqns_list = list(table_fqns)
@@ -1771,7 +1964,7 @@ def _test_multiple_get(
        unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
        tracked_out.sum().backward()
        baseline_out.sum().backward()
-        delta_rows = dt.get_delta()
+        delta_rows = dt.get_unique()

        # Verify that the current batch index is correct
        unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1)
@@ -1842,7 +2035,7 @@ def _test_multiple_consumer(
        unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
        tracked_out.sum().backward()
        baseline_out.sum().backward()
-        delta_rows = dt.get_delta_ids(consumer=consumer)
+        delta_rows = dt.get_unique_ids(consumer=consumer)

        # Verify that the current batch index is correct
        unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1)
@@ -1859,3 +2052,124 @@ def _test_multiple_consumer(
            and returned.allclose(expected_ids),
            f"{i=}, {table_fqn=}, mismatch {returned=} vs {expected_ids=}",
        )
+
+
+def _test_duplication_with_momentum(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    """
+    Test momentum tracking functionality in model delta tracker.
+
+    Validates that the tracker correctly captures and stores momentum values from
+    optimizer states when using TrackingMode.MOMENTUM_LAST mode.
+    """
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+        )
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+            dt_model_opt.step()
+            baseline_opt.step()
+
+            delta_rows = dt.get_unique()
+            for table_fqn in table_fqns_list:
+                ids = delta_rows[table_fqn].ids
+                states = none_throws(delta_rows[table_fqn].states)
+
+                unittest.TestCase().assertTrue(states is not None)
+                unittest.TestCase().assertTrue(ids.numel() == states.numel())
+                unittest.TestCase().assertTrue(bool((states != 0).all().item()))
+
+
+def _test_duplication_with_rowwise_adagrad(
+    rank: int,
+    world_size: int,
+    test_params: ModelDeltaTrackerInputTestParams,
+) -> None:
+    with MultiProcessContext(
+        rank=rank,
+        world_size=world_size,
+        backend="nccl" if torch.cuda.is_available() else "gloo",
+    ) as ctx:
+        dt_model, baseline_model = get_models(
+            rank=rank,
+            world_size=world_size,
+            ctx=ctx,
+            embedding_config_type=test_params.embedding_config_type,
+            tables=test_params.embedding_tables,
+            config=test_params.model_tracker_config,
+            optimizer_type=OptimType.EXACT_ROWWISE_ADAGRAD,
+        )
+
+        # read momentum directly from the table
+        tbe: SplitTableBatchedEmbeddingBagsCodegen = (
+            (
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                # attribute `ec`.
+                dt_model._dmp_wrapped_module.module.ec._lookups[0]
+                ._emb_modules[0]
+                .emb_module
+            )
+            if test_params.embedding_config_type == EmbeddingConfig
+            else (
+                dt_model._dmp_wrapped_module.module.ebc._lookups[0]  # pyre-ignore
+                ._emb_modules[0]
+                .emb_module
+            )
+        )
+        assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+        start_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+        dt_model_opt = torch.optim.Adam(dt_model.parameters(), lr=0.1)
+        baseline_opt = torch.optim.Adam(baseline_model.parameters(), lr=0.1)
+        features_list = model_input_generator(test_params.model_inputs, rank)
+
+        dt = dt_model.get_model_tracker()
+        table_fqns = dt.fqn_to_feature_names().keys()
+        table_fqns_list = list(table_fqns)
+
+        for features in features_list:
+            tracked_out = dt_model(features)
+            baseline_out = baseline_model(features)
+            unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
+            tracked_out.sum().backward()
+            baseline_out.sum().backward()
+
+            dt_model_opt.step()
+            baseline_opt.step()
+
+            end_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
+
+            delta_rows = dt.get_unique()
+            table_fqn = table_fqns_list[0]
+
+            ids = delta_rows[table_fqn].ids
+            tracked_momentum = none_throws(delta_rows[table_fqn].states)
+            unittest.TestCase().assertTrue(tracked_momentum is not None)
+            unittest.TestCase().assertTrue(ids.numel() == tracked_momentum.numel())
+            unittest.TestCase().assertTrue(bool((tracked_momentum != 0).all().item()))
+
+            expected_momentum = end_momentums[ids] - start_momentums[ids]
+            unittest.TestCase().assertTrue(tracked_momentum.allclose(expected_momentum))
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index cec95af91..3fbb70063 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -23,13 +23,14 @@ class IndexedLookup:
     batch_idx: int
     ids: torch.Tensor
     states: Optional[torch.Tensor]
+    compact: bool = False
 
 
 @dataclass
-class DeltaRows:
+class UniqueRows:
     r"""
     Data class as an interface for returning and storing compacted ids and
     embeddings or optimizer states.
 
-    compact(List[IndexedLookup]) -> DeltaRows
+    compact(List[IndexedLookup]) -> UniqueRows
     """
 
     ids: torch.Tensor
@@ -41,16 +42,24 @@ class TrackingMode(Enum):
     Tracking mode for ``ModelDeltaTracker``.
 
     Enums:
-        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
-        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
-            enabling precise top-k result calculations. However, this option comes with increased memory usage.
+        ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring.
+        EMBEDDING: Tracks both row IDs and their corresponding embedding values,
+            enabling precise top-k result calculations. However, this option comes
+            with increased memory usage.
+        MOMENTUM_LAST: Tracks both row IDs and their corresponding momentum values. This mode
+            supports approximate top-k delta-row selection.
+        MOMENTUM_DIFF: Tracks both row IDs and their corresponding momentum difference values.
+        ROWWISE_ADAGRAD: Tracks both row IDs and their corresponding rowwise adagrad states.
     """
 
     ID_ONLY = "id_only"
     EMBEDDING = "embedding"
+    MOMENTUM_LAST = "momentum_last"
+    MOMENTUM_DIFF = "momentum_diff"
+    ROWWISE_ADAGRAD = "rowwise_adagrad"
 
 
-class EmbdUpdateMode(Enum):
+class UpdateMode(Enum):
     r"""
     To identify which embedding value to store while tracking.
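
Taken together, the new tracking modes are driven entirely by ModelTrackerConfig. A minimal usage sketch, assuming an already-initialized process group and a sharded model (the optimizer, batch, and variable names below are illustrative):

    config = ModelTrackerConfig(
        tracking_mode=TrackingMode.MOMENTUM_DIFF,
        delete_on_read=True,
    )
    dmp = DistributedModelParallel(model)
    tracker = dmp.init_torchrec_delta_tracker(config)

    out = dmp(batch)
    out.sum().backward()
    optimizer.step()  # optimizer state must advance before reading deltas

    rows = dmp.get_unique()  # Dict[table_fqn, UniqueRows]; in MOMENTUM_DIFF mode the
                             # states hold the change in the rowwise Adagrad running
                             # sum for each touched row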