Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions torchrec/distributed/model_tracker/delta_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from torchrec.distributed.model_tracker.types import (
IndexedLookup,
RawIndexedLookup,
UniqueRows,
UpdateMode,
)
Expand Down Expand Up @@ -90,7 +91,7 @@ def append(
batch_idx: int,
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
) -> None:
"""
Expand Down Expand Up @@ -162,12 +163,12 @@ def append(
batch_idx: int,
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
)
self.per_fqn_lookups[fqn] = table_fqn_lookup

Expand Down Expand Up @@ -264,3 +265,64 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
ids=compact_ids, states=compact_states, mode=self.updateMode
)
return delta_per_table_fqn


class RawIdTrackerStore(DeltaStore):
"""
RawIdTrackerStore is a concrete implementation of DeltaStore that stores and manages raw ids tracked by RawIdTracker.
"""

def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None:
super().__init__(updateMode)
self.updateMode = updateMode
self.per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {}

def append(
self,
batch_idx: int,
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
)
self.per_fqn_lookups[fqn] = table_fqn_lookup

def delete(self, up_to_idx: Optional[int] = None) -> None:
"""
Delete all idx from the store up to `up_to_idx`
"""
if up_to_idx is None:
# If up_to_idx is None, delete all lookups
self.per_fqn_lookups = {}
else:
# lookups are sorted by idx.
up_to_idx = none_throws(up_to_idx)
for table_fqn, lookups in self.per_fqn_lookups.items():
# remove all lookups up to up_to_idx
self.per_fqn_lookups[table_fqn] = [
lookup for lookup in lookups if lookup.batch_idx >= up_to_idx
]

def compact(self, start_idx: int, end_idx: int) -> None:
pass

def get_indexed_lookups(
self, start_idx: int, end_idx: int
) -> Dict[str, List[RawIndexedLookup]]:
r"""
Return all unique/delta ids per table from the Delta Store.
"""
per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {}
for table_fqn, lookups in self.per_fqn_lookups.items():
indexices = [h.batch_idx for h in lookups]
index_l = bisect_left(indexices, start_idx)
index_r = bisect_left(indexices, end_idx)
per_fqn_lookups[table_fqn] = lookups[index_l:index_r]
return per_fqn_lookups

def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
return {}
5 changes: 2 additions & 3 deletions torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ShardedManagedCollisionEmbeddingBagCollection,
)
from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection
from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore

from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER])
}

self.store: DeltaStoreTrec = DeltaStoreTrec()
self.store: RawIdTrackerStore = RawIdTrackerStore()

# Mapping feature name to corresponding FQNs. This is used for retrieving
# the FQN associated with a given feature name in record_lookup().
Expand Down Expand Up @@ -212,7 +212,6 @@ def record_lookup(
batch_idx=self.curr_batch_idx,
fqn=table_fqn,
ids=torch.cat(ids_list),
states=None,
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
)

Expand Down
12 changes: 11 additions & 1 deletion torchrec/distributed/model_tracker/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@ class IndexedLookup:
batch_idx: int
ids: torch.Tensor
states: Optional[torch.Tensor]
raw_ids: Optional[torch.Tensor] = None
compact: bool = False


@dataclass
class RawIndexedLookup:
r"""
Data class for storing per batch lookedup ids and embeddings or optimizer states.
"""

batch_idx: int
ids: torch.Tensor
raw_ids: Optional[torch.Tensor] = None


@dataclass
class UniqueRows:
r"""
Expand Down
Loading