From a2677641798bc6aa07f532844391ec0709ddad0f Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Wed, 12 Nov 2025 22:34:25 -0800 Subject: [PATCH] Raw id tracker store (#3541) Summary: Introducing a seperate store for raw id tracker for specifically trackiing ids from RawIdTracker. Reviewed By: chouxi Differential Revision: D86524689 --- .../distributed/model_tracker/delta_store.py | 68 ++++++++++++++++++- .../model_tracker/trackers/raw_id_tracker.py | 5 +- torchrec/distributed/model_tracker/types.py | 12 +++- 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py index cfac71b8c..d067c055d 100644 --- a/torchrec/distributed/model_tracker/delta_store.py +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -14,6 +14,7 @@ import torch from torchrec.distributed.model_tracker.types import ( IndexedLookup, + RawIndexedLookup, UniqueRows, UpdateMode, ) @@ -90,7 +91,7 @@ def append( batch_idx: int, fqn: str, ids: torch.Tensor, - states: Optional[torch.Tensor], + states: Optional[torch.Tensor] = None, raw_ids: Optional[torch.Tensor] = None, ) -> None: """ @@ -162,12 +163,12 @@ def append( batch_idx: int, fqn: str, ids: torch.Tensor, - states: Optional[torch.Tensor], + states: Optional[torch.Tensor] = None, raw_ids: Optional[torch.Tensor] = None, ) -> None: table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) table_fqn_lookup.append( - IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids) + IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) ) self.per_fqn_lookups[fqn] = table_fqn_lookup @@ -264,3 +265,64 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: ids=compact_ids, states=compact_states, mode=self.updateMode ) return delta_per_table_fqn + + +class RawIdTrackerStore(DeltaStore): + """ + RawIdTrackerStore is a concrete implementation of DeltaStore that stores and manages raw ids tracked by RawIdTracker. + """ + + def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None: + super().__init__(updateMode) + self.updateMode = updateMode + self.per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {} + + def append( + self, + batch_idx: int, + fqn: str, + ids: torch.Tensor, + states: Optional[torch.Tensor] = None, + raw_ids: Optional[torch.Tensor] = None, + ) -> None: + table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) + table_fqn_lookup.append( + RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids) + ) + self.per_fqn_lookups[fqn] = table_fqn_lookup + + def delete(self, up_to_idx: Optional[int] = None) -> None: + """ + Delete all idx from the store up to `up_to_idx` + """ + if up_to_idx is None: + # If up_to_idx is None, delete all lookups + self.per_fqn_lookups = {} + else: + # lookups are sorted by idx. + up_to_idx = none_throws(up_to_idx) + for table_fqn, lookups in self.per_fqn_lookups.items(): + # remove all lookups up to up_to_idx + self.per_fqn_lookups[table_fqn] = [ + lookup for lookup in lookups if lookup.batch_idx >= up_to_idx + ] + + def compact(self, start_idx: int, end_idx: int) -> None: + pass + + def get_indexed_lookups( + self, start_idx: int, end_idx: int + ) -> Dict[str, List[RawIndexedLookup]]: + r""" + Return all unique/delta ids per table from the Delta Store. + """ + per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {} + for table_fqn, lookups in self.per_fqn_lookups.items(): + indexices = [h.batch_idx for h in lookups] + index_l = bisect_left(indexices, start_idx) + index_r = bisect_left(indexices, end_idx) + per_fqn_lookups[table_fqn] = lookups[index_l:index_r] + return per_fqn_lookups + + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: + return {} diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py index 42eeb90e9..b64792912 100644 --- a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py +++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py @@ -23,7 +23,7 @@ ShardedManagedCollisionEmbeddingBagCollection, ) from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection -from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec +from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows @@ -71,7 +71,7 @@ def __init__( c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER]) } - self.store: DeltaStoreTrec = DeltaStoreTrec() + self.store: RawIdTrackerStore = RawIdTrackerStore() # Mapping feature name to corresponding FQNs. This is used for retrieving # the FQN associated with a given feature name in record_lookup(). @@ -212,7 +212,6 @@ def record_lookup( batch_idx=self.curr_batch_idx, fqn=table_fqn, ids=torch.cat(ids_list), - states=None, raw_ids=torch.cat(per_table_raw_ids[table_fqn]), ) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py index f279f40e1..15988e872 100644 --- a/torchrec/distributed/model_tracker/types.py +++ b/torchrec/distributed/model_tracker/types.py @@ -23,10 +23,20 @@ class IndexedLookup: batch_idx: int ids: torch.Tensor states: Optional[torch.Tensor] - raw_ids: Optional[torch.Tensor] = None compact: bool = False +@dataclass +class RawIndexedLookup: + r""" + Data class for storing per batch lookedup ids and embeddings or optimizer states. + """ + + batch_idx: int + ids: torch.Tensor + raw_ids: Optional[torch.Tensor] = None + + @dataclass class UniqueRows: r"""