271 changes: 251 additions & 20 deletions torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -6,13 +6,16 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging as logger
from collections import Counter, OrderedDict
from typing import Dict, Iterable, List, Optional

import torch

from torch import nn
from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.model_tracker.delta_store import DeltaStore
from torchrec.distributed.model_tracker.types import (
DeltaRows,
EmbdUpdateMode,
@@ -30,7 +33,7 @@
}

# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)


class ModelDeltaTracker:
@@ -39,16 +42,19 @@ class ModelDeltaTracker:
ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
unique rows for a given model, which can help compute top-k rows or stream updated embeddings from predictors to trainers during
online training. Unique IDs or states can be retrieved by calling the get_delta() method.

Args:
model (nn.Module): the model to track.
consumers (List[str], optional): list of consumers to track. Each consumer will
have its own batch offset index. Every get_delta and get_delta_ids invocation will
only return the new values for the given consumer since last call.
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
auto_compact (bool, optional): Trigger compaction automatically during communication at each train cycle.
When set to False, compaction is triggered at the get_delta() call. Default: False.
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
fqns_to_skip (Iterable[str], optional): list of FQNs to skip tracking. Default: ().
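
Example (illustrative sketch; `sharded_model` and the consumer name "publisher" are
placeholders, and it assumes embedding lookups for tracked modules are recorded into
the tracker during the forward pass):

    tracker = ModelDeltaTracker(
        model=sharded_model,
        consumers=["publisher"],
        mode=TrackingMode.ID_ONLY,
    )
    # ... run forward passes so per-batch lookups get recorded ...
    delta = tracker.get_delta(consumer="publisher")
    for fqn, rows in delta.items():
        print(fqn, rows.ids)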

"""

DEFAULT_CONSUMER: str = "default"
@@ -58,41 +64,242 @@ def __init__(
model: nn.Module,
consumers: Optional[List[str]] = None,
delete_on_read: bool = True,
auto_compact: bool = False,
mode: TrackingMode = TrackingMode.ID_ONLY,
fqns_to_skip: Iterable[str] = (),
) -> None:
self._model = model
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
self._delete_on_read = delete_on_read
self._auto_compact = auto_compact
self._mode = mode
self._fqn_to_feature_map: Dict[str, List[str]] = {}
self._fqns_to_skip: Iterable[str] = fqns_to_skip

# per_consumer_batch_idx is used to track the batch index for each consumer.
# This is used to retrieve the delta values for a given consumer as well as
# start_ids for compaction window.
self.per_consumer_batch_idx: Dict[str, int] = {
c: -1 for c in (consumers or [self.DEFAULT_CONSUMER])
}
self.curr_batch_idx: int = 0
self.curr_compact_index: int = 0

self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])

# from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
self.tracked_modules: Dict[str, nn.Module] = {}
self.feature_to_fqn: Dict[str, str] = {}
# Generate the mapping from FQN to feature names.
self.fqn_to_feature_names()
# Validate the mode is supported for the given module
self._validate_mode()

# Mapping feature name to corresponding FQNs. This is used for retrieving
# the FQN associated with a given feature name in record_lookup().
for fqn, feature_names in self._fqn_to_feature_map.items():
for feature_name in feature_names:
if feature_name in self.feature_to_fqn:
logger.warning(f"Duplicate feature name: {feature_name} in fqn {fqn}")
continue
self.feature_to_fqn[feature_name] = fqn
logger.info(f"feature_to_fqn: {self.feature_to_fqn}")

def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
"""
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.

This method is run post-lookup, after the embedding lookup has been performed,
as it needs access to both the input IDs and the resulting embeddings.

This function processes the input KeyedJaggedTensor and records either just the IDs
(in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).

Args:
kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
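
Example (illustrative sketch; assumes "f1" and "f2" are feature names of a tracked
table and the embedding dimension is 4):

    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1", "f2"],
        values=torch.tensor([10, 11, 12]),
        lengths=torch.tensor([2, 1]),
    )
    # one embedding row per id; states is only read in EMBEDDING mode
    states = torch.randn(3, 4)
    tracker.record_lookup(kjt, states)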
"""

# In ID_ONLY mode, we only track feature IDs received in the current batch.
if self._mode == TrackingMode.ID_ONLY:
self.record_ids(kjt)
# In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
elif self._mode == TrackingMode.EMBEDDING:
self.record_embeddings(kjt, states)

else:
raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

def record_ids(self, kjt: KeyedJaggedTensor) -> None:
"""
Record Ids from a given KeyedJaggedTensor.

Args:
kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
"""
per_table_ids: Dict[str, List[torch.Tensor]] = {}
for key in kjt.keys():
table_fqn = self.feature_to_fqn[key]
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
ids_list.append(kjt[key].values())
per_table_ids[table_fqn] = ids_list

for table_fqn, ids_list in per_table_ids.items():
self.store.append(
batch_idx=self.curr_batch_idx,
table_fqn=table_fqn,
ids=torch.cat(ids_list),
embeddings=None,
)

def record_embeddings(
self, kjt: KeyedJaggedTensor, embeddings: torch.Tensor
) -> None:
"""
Record IDs along with their corresponding embeddings from a given KeyedJaggedTensor and embeddings tensor.

Args:
kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
embeddings (torch.Tensor): the embeddings to record.
"""
per_table_ids: Dict[str, List[torch.Tensor]] = {}
per_table_emb: Dict[str, List[torch.Tensor]] = {}
assert embeddings.numel() % kjt.values().numel() == 0, (
f"ids and embeddings size mismatch, expect [{kjt.values().numel()} * emb_dim], "
f"but got {embeddings.numel()}"
)
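# Worked example (illustrative): 3 ids with emb_dim 4 arrive as a flat tensor of
# 12 values; the view below lines each id up with its own 4-dim embedding row.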
embeddings_2d = embeddings.view(kjt.values().numel(), -1)

offset: int = 0
for key in kjt.keys():
table_fqn = self.feature_to_fqn[key]
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
emb_list: List[torch.Tensor] = per_table_emb.get(table_fqn, [])

ids = kjt[key].values()
ids_list.append(ids)
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
offset += ids.numel()

per_table_ids[table_fqn] = ids_list
per_table_emb[table_fqn] = emb_list

for table_fqn, ids_list in per_table_ids.items():
self.store.append(
batch_idx=self.curr_batch_idx,
table_fqn=table_fqn,
ids=torch.cat(ids_list),
embeddings=torch.cat(per_table_emb[table_fqn]),
)

def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
"""
Return a dictionary of hit local IDs for each sparse feature. Ids are
first keyed by submodule FQN.

Args:
consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
"""
per_table_delta_rows = self.get_delta(consumer)
return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
"""
Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.

Args:
consumer (str, optional): The consumer to retrieve delta values for. If not specified, "default" is used as the default consumer.
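
Example (sketch; assumes the tracker was created with consumers=["a", "b"] and
lookups have already been recorded):

    rows_a = tracker.get_delta(consumer="a")  # rows seen since "a" last read
    rows_b = tracker.get_delta(consumer="b")  # "b" keeps its own batch offset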
"""
consumer = consumer or self.DEFAULT_CONSUMER
assert (
consumer in self.per_consumer_batch_idx
), f"consumer {consumer} not present in {self.per_consumer_batch_idx.keys()}"

index_end: int = self.curr_batch_idx + 1
index_start = max(self.per_consumer_batch_idx.values())

# In case of multiple consumers, it is possible that the previous consumer has already compacted these indices
# and index_start could be equal to index_end, in which case we should not compact again.
if index_start < index_end:
self.compact(index_start, index_end)
tracker_rows = self.store.get_delta(
from_idx=self.per_consumer_batch_idx[consumer]
)
self.per_consumer_batch_idx[consumer] = index_end
if self._delete_on_read:
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
return tracker_rows

def get_tracked_modules(self) -> Dict[str, nn.Module]:
"""
Returns a dictionary of tracked modules.
"""
return self.tracked_modules

def fqn_to_feature_names(self) -> Dict[str, List[str]]:
"""
Returns a mapping of FQN to feature names for all supported modules (ShardedEmbeddingCollection and ShardedEmbeddingBagCollection) present in the given model.
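
The returned mapping looks like the sketch below (module FQNs and feature names
are illustrative placeholders only):

    {"sparse.ec": ["feature_a", "feature_b"], "sparse.ebc": ["feature_c"]}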
"""
if (self._fqn_to_feature_map is not None) and len(self._fqn_to_feature_map) > 0:
return self._fqn_to_feature_map

table_to_feature_names: Dict[str, List[str]] = OrderedDict()
table_to_fqn: Dict[str, str] = OrderedDict()
for fqn, named_module in self._model.named_modules():
split_fqn = fqn.split(".")
# Skipping partial FQNs present in fqns_to_skip
# TODO: Validate if we need to support more complex patterns for skipping fqns
should_skip = False
for fqn_to_skip in self._fqns_to_skip:
if fqn_to_skip in split_fqn:
logger.info(f"Skipping {fqn} because it is part of fqns_to_skip")
should_skip = True
break
if should_skip:
continue
# Using FQNs of the embedding and mapping them to features as state_dict() API uses these to key states.
if isinstance(named_module, SUPPORTED_MODULES):
for table_name, config in named_module._table_name_to_config.items():
logger.info(
f"Found {table_name} for {fqn} with features {config.feature_names}"
)
table_to_feature_names[table_name] = config.feature_names
self.tracked_modules[self._clean_fqn_fn(fqn)] = named_module
for table_name in table_to_feature_names:
# Using the split FQN to get the exact table name matching. Otherwise, checking "table_name in fqn"
# will incorrectly match fqn with all the table names that have the same prefix
if table_name in split_fqn:
embedding_fqn = self._clean_fqn_fn(fqn)
if table_name in table_to_fqn:
# Sanity check for validating that we don't have more than one table mapping to the same FQN.
logger.warning(
f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
)
table_to_fqn[table_name] = embedding_fqn
logger.info(f"Table to fqn: {table_to_fqn}")
flatten_names = [
name for names in table_to_feature_names.values() for name in names
]
# TODO: Validate if there is a better way to handle duplicate feature names.
# Logging a warning if duplicate feature names are found across tables, but continue execution as this could be a valid case.
if len(set(flatten_names)) != len(flatten_names):
counts = Counter(flatten_names)
duplicates = [item for item, count in counts.items() if count > 1]
logger.warning(f"duplicate feature names found: {duplicates}")

fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
for table_name in table_to_feature_names:
if table_name not in table_to_fqn:
# This is likely unexpected, where we can't locate the FQN associated with this table.
logger.warning(
f"Table {table_name} not found in {table_to_fqn}, skipping"
)
continue
fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
table_name
]
self._fqn_to_feature_map = fqn_to_feature_names
return fqn_to_feature_names

def clear(self, consumer: Optional[str] = None) -> None:
"""
Expand All @@ -101,7 +308,19 @@ def clear(self, consumer: Optional[str] = None) -> None:
Args:
consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
"""
# 1. If consumer is None, delete globally.
if consumer is None:
self.store.delete()
return

assert (
consumer in self.per_consumer_batch_idx
), f"consumer {consumer} not found in {self.per_consumer_batch_idx.keys()}"

# 2. For single consumer, we can just delete all ids
if len(self.per_consumer_batch_idx) == 1:
self.store.delete()
return

def compact(self, start_idx: int, end_idx: int) -> None:
"""
@@ -111,4 +330,16 @@ def compact(self, start_idx: int, end_idx: int) -> None:
start_idx (int): Starting index for compaction.
end_idx (int): Ending index for compaction.
"""
self.store.compact(start_idx, end_idx)

def _clean_fqn_fn(self, fqn: str) -> str:
# strip DMP internal module FQN prefix to match state dict FQN
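# e.g. "_dmp_wrapped_module.module.sparse.ec" -> "sparse.ec" (FQN tail is illustrative)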
return fqn.replace("_dmp_wrapped_module.module.", "")

def _validate_mode(self) -> None:
"""Validate that the tracking mode is supported for all tracked modules."""
for module in self.tracked_modules.values():
assert not (
isinstance(module, ShardedEmbeddingBagCollection)
and self._mode == TrackingMode.EMBEDDING
), "EBC's lookup returns pooled embeddings and we currently do not support tracking raw embeddings."