2 changes: 1 addition & 1 deletion torchrec/distributed/embedding.py
@@ -1587,7 +1587,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(features, embs, self)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
55 changes: 52 additions & 3 deletions torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
import logging
from abc import ABC
from collections import OrderedDict
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -208,6 +208,10 @@ def __init__(
)

self.grouped_configs = grouped_configs
# Model tracker function to track optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None

def _create_embedding_kernel(
self,
@@ -315,7 +319,13 @@ def forward(
self._feature_splits,
)
for emb_op, features in zip(self._emb_modules, features_by_group):
embeddings.append(emb_op(features).view(-1))
lookup = emb_op(features).view(-1)
embeddings.append(lookup)

# The optimizer state tracker function is only set (and therefore called)
# when the model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(features, lookup, emb_op)

return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

@@ -420,6 +430,21 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Register a model tracker function to track optimizer state.

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
"""
@@ -519,6 +544,10 @@ def __init__(
if scale_weight_gradients and get_gradient_division()
else 1
)
# Model tracker function to track optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None

def _create_embedding_kernel(
self,
@@ -678,7 +707,12 @@ def forward(
features._weights, self._scale_gradient_factor
)

embeddings.append(emb_op(features))
lookup = emb_op(features)
embeddings.append(lookup)
# The optimizer state tracker function is only set (and therefore called)
# when the model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(features, lookup, emb_op)

if features.variable_stride_per_key() and len(self._emb_modules) > 1:
stride_per_rank_per_key = list(
@@ -811,6 +845,21 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Register a model tracker function to track optimizer state.

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class MetaInferGroupedEmbeddingsLookup(
BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
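The hunks above wire an `optim_state_tracker_fn` hook into both grouped lookup classes. A minimal sketch of a callback matching the registered signature follows; the signature comes from the diff, while the lookup module name (`lookup_module`) and the logging body are illustrative assumptions.

from typing import Optional

import torch
import torch.nn as nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def record_optim_state(
    features: KeyedJaggedTensor,
    lookup: torch.Tensor,
    emb_op: Optional[nn.Module],
) -> None:
    # A real tracker would read optimizer state through the emb_op that
    # produced this lookup; here we only print shapes to show the contract.
    print(features.keys(), tuple(lookup.shape), type(emb_op).__name__)


# lookup_module.register_optim_state_tracker_fn(record_optim_state)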
8 changes: 5 additions & 3 deletions torchrec/distributed/embedding_types.py
@@ -391,7 +391,7 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []
self.post_lookup_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor], None]
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None
self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

@@ -444,14 +444,16 @@ def train(self, mode: bool = True): # pyre-ignore[3]

def register_post_lookup_tracker_fn(
self,
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Register a function to be called after lookup is done. This is used for
tracking the lookup results and optimizer states.

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
if self.post_lookup_tracker_fn is not None:
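With the widened `post_lookup_tracker_fn` signature, the callback now also receives the sharded module that produced the embeddings. A sketch of a conforming callback, assuming a hypothetical sharded module named `sharded_module` on which `register_post_lookup_tracker_fn` would be called; the per-module bookkeeping is purely illustrative.

from typing import Dict, Optional

import torch
import torch.nn as nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Ids seen so far, keyed by the class name of the module that emitted them.
seen_ids: Dict[str, torch.Tensor] = {}


def post_lookup_record(
    features: KeyedJaggedTensor,
    embs: torch.Tensor,
    module: Optional[nn.Module],
) -> None:
    key = type(module).__name__ if module is not None else "unknown"
    ids = features.values()
    seen_ids[key] = torch.cat([seen_ids[key], ids]) if key in seen_ids else ids


# sharded_module.register_post_lookup_tracker_fn(post_lookup_record)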
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(features, embs, self)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST,
14 changes: 7 additions & 7 deletions torchrec/distributed/model_parallel.py
@@ -29,7 +29,7 @@
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
@@ -293,7 +293,7 @@ def __init__(
if init_data_parallel:
self.init_data_parallel()

self.model_delta_tracker: Optional[ModelDeltaTracker] = (
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
if model_tracker_config is not None
else None
@@ -369,9 +369,9 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:

def _init_delta_tracker(
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
) -> ModelDeltaTracker:
) -> ModelDeltaTrackerTrec:
# Init delta tracker if config is provided
return ModelDeltaTracker(
return ModelDeltaTrackerTrec(
model=module,
consumers=model_tracker_config.consumers,
delete_on_read=model_tracker_config.delete_on_read,
@@ -456,7 +456,7 @@ def init_parameters(module: nn.Module) -> None:

module.apply(init_parameters)

def get_model_tracker(self) -> ModelDeltaTracker:
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
"""
Returns the model tracker if it exists.
"""
@@ -466,14 +466,14 @@ def get_model_tracker(self) -> ModelDeltaTracker:
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker

def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
"""
Returns the delta rows for the given consumer.
"""
assert (
self.model_delta_tracker is not None
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker.get_delta(consumer)
return self.model_delta_tracker.get_unique(consumer)

def sparse_grad_parameter_names(
self, destination: Optional[List[str]] = None, prefix: str = ""
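On the DistributedModelParallel side, the tracker type becomes `ModelDeltaTrackerTrec` and the public accessor is renamed from `get_delta` to `get_unique`. A rough call-flow sketch; the DMP construction is kept commented out because it depends on an initialized process group and a sharded `model`, and the `ModelTrackerConfig` fields are not shown here.

from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.model_tracker.types import ModelTrackerConfig

# dmp = DistributedModelParallel(
#     module=model,
#     model_tracker_config=ModelTrackerConfig(),
# )
# tracker = dmp.get_model_tracker()   # now typed as ModelDeltaTrackerTrec
# delta_rows = dmp.get_unique()       # previously dmp.get_delta()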
89 changes: 81 additions & 8 deletions torchrec/distributed/model_tracker/delta_store.py
@@ -7,6 +7,7 @@


# pyre-strict
from abc import ABC, abstractmethod
from bisect import bisect_left
from typing import Dict, List, Optional

@@ -67,34 +68,106 @@ def _compute_unique_rows(
return DeltaRows(ids=unique_ids, states=unique_states)


class DeltaStore:
class DeltaStore(ABC):
"""
DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
various batches during training, designed to be used with TorchRecs ModelDeltaTracker.
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
DeltaStore is an abstract base class that defines the interface for storing and managing
local delta (row) updates for embeddings/states across various batches during training.

Implementations should maintain a representation of requested ids and embeddings/states,
providing a way to compact and get delta updates for each embedding table.

The class supports different embedding update modes (NONE, FIRST, LAST) to determine
how to handle duplicate ids when compacting or retrieving embeddings.
"""

@abstractmethod
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
pass

@abstractmethod
def append(
self,
batch_idx: int,
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
) -> None:
"""
Append a batch of ids and states to the store for a specific table.

Args:
batch_idx: The batch index
fqn: The fully qualified name of the table
ids: The tensor of ids to append
states: Optional tensor of states to append
"""
pass

@abstractmethod
def delete(self, up_to_idx: Optional[int] = None) -> None:
"""
Delete all tracked lookups from the store up to `up_to_idx`.

Args:
up_to_idx: Optional index up to which to delete lookups
"""
pass

@abstractmethod
def compact(self, start_idx: int, end_idx: int) -> None:
"""
Compact (ids, embeddings) in batch index range from start_idx to end_idx.

Args:
start_idx: The starting batch index
end_idx: The ending batch index
"""
pass

@abstractmethod
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
"""
Return all unique/delta ids per table from the Delta Store.

Args:
from_idx: The batch index from which to get deltas

Returns:
A dictionary mapping table FQNs to their delta rows
"""
pass


class DeltaStoreTrec(DeltaStore):
"""
DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages
local delta (row) updates for embeddings/states across various batches during training,
designed to be used with TorchRec's ModelDeltaTracker.

It maintains a CUDA in-memory representation of requested ids and embeddings/states,
providing a way to compact and get delta updates for each embedding table.

The class supports different embedding update modes (NONE, FIRST, LAST) to determine
how to handle duplicate ids when compacting or retrieving embeddings.
"""

def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
super().__init__(embdUpdateMode)
self.embdUpdateMode = embdUpdateMode
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}

def append(
self,
batch_idx: int,
table_fqn: str,
fqn: str,
ids: torch.Tensor,
states: Optional[torch.Tensor],
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
)
self.per_fqn_lookups[table_fqn] = table_fqn_lookup
self.per_fqn_lookups[fqn] = table_fqn_lookup

def delete(self, up_to_idx: Optional[int] = None) -> None:
"""
@@ -151,7 +224,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
)
self.per_fqn_lookups = new_per_fqn_lookups

def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
r"""
Return all unique/delta ids per table from the Delta Store.
"""
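A minimal sketch of the `DeltaStoreTrec` flow introduced above, assuming the import paths shown and using CPU tensors for illustration (the store normally holds CUDA tensors during training). With `EmbdUpdateMode.NONE` only ids are tracked, so `states` may be None; the table FQN and id values are made up.

import torch
from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
from torchrec.distributed.model_tracker.types import EmbdUpdateMode

store = DeltaStoreTrec(embdUpdateMode=EmbdUpdateMode.NONE)

# Two batches touch the same table FQN with overlapping ids.
store.append(batch_idx=0, fqn="ebc.table_0", ids=torch.tensor([1, 2, 3]), states=None)
store.append(batch_idx=1, fqn="ebc.table_0", ids=torch.tensor([3, 4]), states=None)

store.compact(start_idx=0, end_idx=2)  # merge the tracked lookups for batches 0 and 1
unique = store.get_unique(from_idx=0)  # previously get_delta(); returns Dict[str, DeltaRows]
print(unique["ebc.table_0"].ids)       # deduplicated ids for the table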