179 changes: 179 additions & 0 deletions torchrec/distributed/model_tracker/delta_store.py
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# pyre-strict
from bisect import bisect_left
from typing import Dict, List, Optional

import torch
from torchrec.distributed.model_tracker.types import (
DeltaRows,
EmbdUpdateMode,
IndexedLookup,
)
from torchrec.distributed.utils import none_throws


def _compute_unique_rows(
ids: List[torch.Tensor],
embeddings: Optional[List[torch.Tensor]],
mode: EmbdUpdateMode,
) -> DeltaRows:
r"""
To calculate unique ids and embeddings
"""
if mode == EmbdUpdateMode.NONE:
assert (
embeddings is None
), f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
unique_ids = torch.cat(ids).unique(return_inverse=False)
return DeltaRows(ids=unique_ids, embeddings=None)
else:
assert (
embeddings is not None
), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"

cat_ids = torch.cat(ids)
cat_embeddings = torch.cat(embeddings)

if mode == EmbdUpdateMode.LAST:
cat_ids = cat_ids.flip(dims=[0])
cat_embeddings = cat_embeddings.flip(dims=[0])

# Get unique ids and inverse mapping (each element's index in unique_ids).
unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True)

# Create a tensor of original indices. This will be used to find first occurrences of ids.
all_indices = torch.arange(cat_ids.size(0), device=cat_ids.device)

# Initialize tensor for first occurrence indices (filled with a high value).
first_occurrence = torch.full(
(unique_ids.size(0),),
cat_ids.size(0),
dtype=torch.int64,
device=cat_ids.device,
)

# Scatter indices using inverse mapping and reduce with "amin" to get first or last (if reversed) occurrence per unique id.
first_occurrence = first_occurrence.scatter_reduce(
0, inverse, all_indices, reduce="amin"
)

# Use first occurrence indices to select corresponding embedding row.
        unique_embeddings = cat_embeddings[first_occurrence]
        return DeltaRows(ids=unique_ids, embeddings=unique_embeddings)
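
# Worked example (illustrative, not part of the original change) of the dedup
# behavior above, using made-up tensor values:
#
#   ids        = [tensor([1, 2]), tensor([2, 3])]
#   embeddings = [tensor([[0.1], [0.2]]), tensor([[0.3], [0.4]])]
#
#   _compute_unique_rows(ids, None, EmbdUpdateMode.NONE)
#       -> ids tensor([1, 2, 3]), embeddings None
#   _compute_unique_rows(ids, embeddings, EmbdUpdateMode.FIRST)
#       -> duplicate id 2 keeps its first value, [0.2]
#   _compute_unique_rows(ids, embeddings, EmbdUpdateMode.LAST)
#       -> duplicate id 2 keeps its last value, [0.3]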


class DeltaStore:
"""
DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
    various batches during training, designed to be used with TorchRec's ModelDeltaTracker.
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
providing a way to compact and get delta updates for each embedding table.

The class supports different embedding update modes (NONE, FIRST, LAST) to determine
how to handle duplicate ids when compacting or retrieving embeddings.

"""

def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
self.embdUpdateMode = embdUpdateMode
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}

def append(
self,
batch_idx: int,
table_fqn: str,
ids: torch.Tensor,
embeddings: Optional[torch.Tensor],
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
table_fqn_lookup.append(
IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings)
)
self.per_fqn_lookups[table_fqn] = table_fqn_lookup

def delete(self, up_to_idx: Optional[int] = None) -> None:
"""
Delete all idx from the store up to `up_to_idx`
"""
if up_to_idx is None:
# If up_to_idx is None, delete all lookups
self.per_fqn_lookups = {}
else:
# lookups are sorted by idx.
up_to_idx = none_throws(up_to_idx)
for table_fqn, lookups in self.per_fqn_lookups.items():
# remove all lookups up to up_to_idx
self.per_fqn_lookups[table_fqn] = [
lookup for lookup in lookups if lookup.batch_idx >= up_to_idx
]

def compact(self, start_idx: int, end_idx: int) -> None:
r"""
Compact (ids, embeddings) in batch index range from start_idx, curr_batch_idx.
"""
        assert (
            start_idx < end_idx
        ), f"start_idx ({start_idx}) must be smaller than end_idx ({end_idx})"

new_per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
for table_fqn, lookups in self.per_fqn_lookups.items():
            indices = [h.batch_idx for h in lookups]
            index_l = bisect_left(indices, start_idx)
            index_r = bisect_left(indices, end_idx)
lookups_to_compact = lookups[index_l:index_r]
if len(lookups_to_compact) <= 1:
new_per_fqn_lookups[table_fqn] = lookups
continue
ids = [lookup.ids for lookup in lookups_to_compact]
embeddings = (
[none_throws(lookup.embeddings) for lookup in lookups_to_compact]
if self.embdUpdateMode != EmbdUpdateMode.NONE
else None
)
delta_rows = _compute_unique_rows(
ids=ids, embeddings=embeddings, mode=self.embdUpdateMode
)
new_per_fqn_lookups[table_fqn] = (
lookups[:index_l]
+ [
IndexedLookup(
batch_idx=start_idx,
ids=delta_rows.ids,
embeddings=delta_rows.embeddings,
)
]
+ lookups[index_r:]
)
self.per_fqn_lookups = new_per_fqn_lookups

def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
r"""
Return all unique/delta ids per table from the Delta Store.
"""

delta_per_table_fqn: Dict[str, DeltaRows] = {}
for table_fqn, lookups in self.per_fqn_lookups.items():
compact_ids = [
lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx
]
compact_embeddings = (
[
none_throws(lookup.embeddings)
for lookup in lookups
if lookup.batch_idx >= from_idx
]
if self.embdUpdateMode != EmbdUpdateMode.NONE
else None
)

delta_per_table_fqn[table_fqn] = _compute_unique_rows(
ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode
)
return delta_per_table_fqn
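
A minimal usage sketch of DeltaStore (not part of the diff): the table FQN and tensor values are illustrative
placeholders, and CPU tensors are used for brevity even though the store is described as holding CUDA tensors.

import torch

from torchrec.distributed.model_tracker.delta_store import DeltaStore
from torchrec.distributed.model_tracker.types import EmbdUpdateMode

# Track ids and keep the first-seen embedding value per id.
store = DeltaStore(embdUpdateMode=EmbdUpdateMode.FIRST)

# Record two batches of lookups for the same (hypothetical) table FQN.
store.append(
    batch_idx=0,
    table_fqn="sparse.ec.table_0",
    ids=torch.tensor([1, 2]),
    embeddings=torch.randn(2, 4),
)
store.append(
    batch_idx=1,
    table_fqn="sparse.ec.table_0",
    ids=torch.tensor([2, 3]),
    embeddings=torch.randn(2, 4),
)

# Merge all lookups with batch_idx in [0, 2) into one deduplicated lookup.
store.compact(start_idx=0, end_idx=2)

# {"sparse.ec.table_0": DeltaRows(ids=tensor([1, 2, 3]), embeddings=<3 x 4 rows>)}
delta = store.get_delta(from_idx=0)

# Drop everything below batch_idx 2 once all consumers have read it.
store.delete(up_to_idx=2)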
114 changes: 114 additions & 0 deletions torchrec/distributed/model_tracker/model_delta_tracker.py
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Dict, List, Optional, Union

import torch

from torch import nn
from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
from torchrec.distributed.model_tracker.types import (
DeltaRows,
EmbdUpdateMode,
TrackingMode,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
# Only IDs are tracked, no additional state is stored.
TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
# TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
# the earliest embedding values are stored since the last checkpoint or snapshot.
# This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
}
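
# Hedged sketch (not part of this change) of the top-k delta computation the
# comment above refers to. `new_emb` and `old_emb` are hypothetical tensors of
# shape (num_tracked_rows, emb_dim) holding the current values and the
# FIRST-tracked baseline values for the same ids:
#
#   per_row_change = (new_emb - old_emb).norm(dim=1)
#   top_values, top_rows = per_row_change.topk(k=min(10, per_row_change.numel()))
#
# EmbdUpdateMode.FIRST matters here: the baseline must be the value at the last
# checkpoint/snapshot, not the most recently written value.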

# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection]


class ModelDeltaTracker:
r"""

    ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional
    support for tracking the corresponding embeddings or states. This is useful for identifying and retrieving the
    latest delta or unique rows for a given model, which can help compute top-k changed rows or stream updated
    embeddings from trainers to predictors during online training. Unique IDs or states can be retrieved by calling
    the get_delta() method.

Args:
model (nn.Module): the model to track.
        consumers (List[str], optional): list of consumers to track. Each consumer will
            have its own batch offset index. Every get_delta() invocation will only
            return the new ids for the given consumer since its last get_delta() call.
delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
"""

DEFAULT_CONSUMER: str = "default"

def __init__(
self,
model: nn.Module,
consumers: Optional[List[str]] = None,
delete_on_read: bool = True,
mode: TrackingMode = TrackingMode.ID_ONLY,
) -> None:
self._model = model
self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
self._delete_on_read = delete_on_read
self._mode = mode

def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
"""
        Record the IDs from a given KeyedJaggedTensor and the corresponding embeddings / parameter states.

Args:
kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
states (torch.Tensor): the states to record.
"""
pass

def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
"""
        Return a dictionary of hit local IDs (and optionally embeddings/states) for each sparse feature,
        keyed by submodule FQN.

Args:
consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer.
"""
return {}

def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
"""
Returns a mapping from FQN to feature names for a given module.

Args:
module (nn.Module): the module to retrieve feature names for.
"""
return {}

def clear(self, consumer: Optional[str] = None) -> None:
"""
Clear tracked IDs for a given consumer.

Args:
consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer.
"""
pass

def compact(self, start_idx: int, end_idx: int) -> None:
"""
Compact tracked IDs for a given range of indices.

Args:
start_idx (int): Starting index for compaction.
end_idx (int): Ending index for compaction.
"""
pass
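
A hedged sketch (not part of the diff) of the intended ModelDeltaTracker call pattern. The methods above are still
stubs in this change, so the returned values are placeholders; `sharded_model` is an assumed variable holding a model
that contains ShardedEmbeddingCollection / ShardedEmbeddingBagCollection modules.

from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
from torchrec.distributed.model_tracker.types import TrackingMode

tracker = ModelDeltaTracker(
    model=sharded_model,
    consumers=["publisher", "validator"],
    delete_on_read=True,
    mode=TrackingMode.EMBEDDING,
)

# ... run training; lookups get recorded via record_lookup(kjt, states) ...

# Each consumer only sees rows that are new since its own previous read.
publisher_rows = tracker.get_delta(consumer="publisher")
validator_rows = tracker.get_delta(consumer="validator")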