Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions torchrec/distributed/model_tracker/delta_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# pyre-strict
from bisect import bisect_left
from typing import Dict, List, Optional

import torch
from torchrec.distributed.model_tracker.types import (
DeltaRows,
EmbdUpdateMode,
IndexedLookup,
)
from torchrec.distributed.utils import none_throws


def _compute_unique_rows(
ids: List[torch.Tensor],
embeddings: Optional[List[torch.Tensor]],
mode: EmbdUpdateMode,
) -> DeltaRows:
r"""
To calculate unique ids and embeddings
"""
if mode == EmbdUpdateMode.NONE:
assert (
embeddings is None
), f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
unique_ids = torch.cat(ids).unique(return_inverse=False)
return DeltaRows(ids=unique_ids, embeddings=None)
else:
assert (
embeddings is not None
), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"

cat_ids = torch.cat(ids)
cat_embeddings = torch.cat(embeddings)

if mode == EmbdUpdateMode.LAST:
cat_ids = cat_ids.flip(dims=[0])
cat_embeddings = cat_embeddings.flip(dims=[0])

# Get unique ids and inverse mapping (each element's index in unique_ids).
unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True)

# Create a tensor of original indices. This will be used to find first occurrences of ids.
all_indices = torch.arange(cat_ids.size(0), device=cat_ids.device)

# Initialize tensor for first occurrence indices (filled with a high value).
first_occurrence = torch.full(
(unique_ids.size(0),),
cat_ids.size(0),
dtype=torch.int64,
device=cat_ids.device,
)

# Scatter indices using inverse mapping and reduce with "amin" to get first or last (if reversed) occurrence per unique id.
first_occurrence = first_occurrence.scatter_reduce(
0, inverse, all_indices, reduce="amin"
)

# Use first occurrence indices to select corresponding embedding row.
unique_embedings = cat_embeddings[first_occurrence]
return DeltaRows(ids=unique_ids, embeddings=unique_embedings)


class DeltaStore:
"""
DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
various batches during training, designed to be used with TorchRecs ModelDeltaTracker.
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
providing a way to compact and get delta updates for each embedding table.

The class supports different embedding update modes (NONE, FIRST, LAST) to determine
how to handle duplicate ids when compacting or retrieving embeddings.

"""

def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
self.embdUpdateMode = embdUpdateMode
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}

def append(
self,
batch_idx: int,
table_fqn: str,
ids: torch.Tensor,
embeddings: Optional[torch.Tensor],
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
table_fqn_lookup.append(
IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings)
)
self.per_fqn_lookups[table_fqn] = table_fqn_lookup

def delete(self, up_to_idx: Optional[int] = None) -> None:
"""
Delete all idx from the store up to `up_to_idx`
"""
if up_to_idx is None:
# If up_to_idx is None, delete all lookups
self.per_fqn_lookups = {}
else:
# lookups are sorted by idx.
up_to_idx = none_throws(up_to_idx)
for table_fqn, lookups in self.per_fqn_lookups.items():
# remove all lookups up to up_to_idx
self.per_fqn_lookups[table_fqn] = [
lookup for lookup in lookups if lookup.batch_idx >= up_to_idx
]

def compact(self, start_idx: int, end_idx: int) -> None:
r"""
Compact (ids, embeddings) in batch index range from start_idx, curr_batch_idx.
"""
assert (
start_idx < end_idx
), f"start_idx {start_idx} must be smaller then end_idx, but got {end_idx}"

new_per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
for table_fqn, lookups in self.per_fqn_lookups.items():
indexices = [h.batch_idx for h in lookups]
index_l = bisect_left(indexices, start_idx)
index_r = bisect_left(indexices, end_idx)
lookups_to_compact = lookups[index_l:index_r]
if len(lookups_to_compact) <= 1:
new_per_fqn_lookups[table_fqn] = lookups
continue
ids = [lookup.ids for lookup in lookups_to_compact]
embeddings = (
[none_throws(lookup.embeddings) for lookup in lookups_to_compact]
if self.embdUpdateMode != EmbdUpdateMode.NONE
else None
)
delta_rows = _compute_unique_rows(
ids=ids, embeddings=embeddings, mode=self.embdUpdateMode
)
new_per_fqn_lookups[table_fqn] = (
lookups[:index_l]
+ [
IndexedLookup(
batch_idx=start_idx,
ids=delta_rows.ids,
embeddings=delta_rows.embeddings,
)
]
+ lookups[index_r:]
)
self.per_fqn_lookups = new_per_fqn_lookups

def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
r"""
Return all unique/delta ids per table from the Delta Store.
"""

delta_per_table_fqn: Dict[str, DeltaRows] = {}
for table_fqn, lookups in self.per_fqn_lookups.items():
compact_ids = [
lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx
]
compact_embeddings = (
[
none_throws(lookup.embeddings)
for lookup in lookups
if lookup.batch_idx >= from_idx
]
if self.embdUpdateMode != EmbdUpdateMode.NONE
else None
)

delta_per_table_fqn[table_fqn] = _compute_unique_rows(
ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode
)
return delta_per_table_fqn
Loading
Loading