Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,7 +1587,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(features, embs, self)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
Expand Down
55 changes: 52 additions & 3 deletions torchrec/distributed/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
from abc import ABC
from collections import OrderedDict
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -208,6 +208,10 @@ def __init__(
)

self.grouped_configs = grouped_configs
# Model tracker function to tracker optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None

def _create_embedding_kernel(
self,
Expand Down Expand Up @@ -315,7 +319,13 @@ def forward(
self._feature_splits,
)
for emb_op, features in zip(self._emb_modules, features_by_group):
embeddings.append(emb_op(features).view(-1))
lookup = emb_op(features).view(-1)
embeddings.append(lookup)

# Model tracker optimizer state function, will only be set called
# when model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(features, lookup, emb_op)

return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

Expand Down Expand Up @@ -420,6 +430,21 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Model tracker function to tracker optimizer state

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
"""
Expand Down Expand Up @@ -519,6 +544,10 @@ def __init__(
if scale_weight_gradients and get_gradient_division()
else 1
)
# Model tracker function to tracker optimizer state
self.optim_state_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None

def _create_embedding_kernel(
self,
Expand Down Expand Up @@ -678,7 +707,12 @@ def forward(
features._weights, self._scale_gradient_factor
)

embeddings.append(emb_op(features))
lookup = emb_op(features)
embeddings.append(lookup)
# Model tracker optimizer state function, will only be set called
# when model tracker is configured to track optimizer state
if self.optim_state_tracker_fn is not None:
self.optim_state_tracker_fn(features, lookup, emb_op)

if features.variable_stride_per_key() and len(self._emb_modules) > 1:
stride_per_rank_per_key = list(
Expand Down Expand Up @@ -811,6 +845,21 @@ def purge(self) -> None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
emb_module.purge()

def register_optim_state_tracker_fn(
self,
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Model tracker function to tracker optimizer state

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
self.optim_state_tracker_fn = record_fn


class MetaInferGroupedEmbeddingsLookup(
BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
Expand Down
8 changes: 5 additions & 3 deletions torchrec/distributed/embedding_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def __init__(
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []
self.post_lookup_tracker_fn: Optional[
Callable[[KeyedJaggedTensor, torch.Tensor], None]
Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
] = None
self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

Expand Down Expand Up @@ -444,14 +444,16 @@ def train(self, mode: bool = True): # pyre-ignore[3]

def register_post_lookup_tracker_fn(
self,
record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
record_fn: Callable[
[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
],
) -> None:
"""
Register a function to be called after lookup is done. This is used for
tracking the lookup results and optimizer states.

Args:
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
record_fn (Callable[[KeyedJaggedTensor, torch.Tensor,Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

"""
if self.post_lookup_tracker_fn is not None:
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,7 @@ def compute_and_output_dist(
):
embs = lookup(features)
if self.post_lookup_tracker_fn is not None:
self.post_lookup_tracker_fn(features, embs)
self.post_lookup_tracker_fn(features, embs, self)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST,
Expand Down
30 changes: 22 additions & 8 deletions torchrec/distributed/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
Expand All @@ -50,6 +50,7 @@
append_prefix,
copy_to_device,
filter_state_dict,
none_throws,
sharded_model_copy,
)
from torchrec.optim.fused import FusedOptimizerModule
Expand Down Expand Up @@ -293,7 +294,7 @@ def __init__(
if init_data_parallel:
self.init_data_parallel()

self.model_delta_tracker: Optional[ModelDeltaTracker] = (
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
if model_tracker_config is not None
else None
Expand Down Expand Up @@ -369,9 +370,9 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:

def _init_delta_tracker(
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
) -> ModelDeltaTracker:
) -> ModelDeltaTrackerTrec:
# Init delta tracker if config is provided
return ModelDeltaTracker(
return ModelDeltaTrackerTrec(
model=module,
consumers=model_tracker_config.consumers,
delete_on_read=model_tracker_config.delete_on_read,
Expand Down Expand Up @@ -456,7 +457,20 @@ def init_parameters(module: nn.Module) -> None:

module.apply(init_parameters)

def get_model_tracker(self) -> ModelDeltaTracker:
def init_torchrec_delta_tracker(
self, model_tracker_config: ModelTrackerConfig
) -> ModelDeltaTrackerTrec:
"""
Initializes the model delta tracker if it doesn't exists.
"""
if self.model_delta_tracker is None:
self.model_delta_tracker = self._init_delta_tracker(
model_tracker_config, self._dmp_wrapped_module
)

return none_throws(self.model_delta_tracker)

def get_model_tracker(self) -> ModelDeltaTrackerTrec:
"""
Returns the model tracker if it exists.
"""
Expand All @@ -466,14 +480,14 @@ def get_model_tracker(self) -> ModelDeltaTracker:
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker

def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]:
"""
Returns the delta rows for the given consumer.
"""
assert (
self.model_delta_tracker is not None
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker.get_delta(consumer)
return self.model_delta_tracker.get_unique(consumer)

def sparse_grad_parameter_names(
self, destination: Optional[List[str]] = None, prefix: str = ""
Expand Down
4 changes: 2 additions & 2 deletions torchrec/distributed/model_tracker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
SUPPORTED_MODULES, # noqa
)
from torchrec.distributed.model_tracker.types import (
DeltaRows, # noqa
EmbdUpdateMode, # noqa
IndexedLookup, # noqa
ModelTrackerConfig, # noqa
TrackingMode, # noqa
UniqueRows, # noqa
UpdateMode, # noqa
)
Loading
Loading