From f4530f1ac28f28582cb396b6a7c2b9ac52b8339c Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Sat, 7 Jun 2025 11:56:06 -0700 Subject: [PATCH] Delta Store (#3056) Summary: # Summary Introducing DeltaStore class which efficiently manages embedding table updates with the following features: * Tracks embedding table updates by table FQN with batch indexing * Supports multiple embedding update modes (NONE, FIRST, LAST) * Provides compaction functionality for calculating unique * Allows retrieval of unique/delta IDs per table with optional embedding values ## How lookups are preserved and fetched? In DeltaStore, lookups are preserved in the `per_fqn_lookups` dictionary, which maps table FQNs to lists of `IndexedLookup` objects. Each `IndexedLookup` contains: 1. `idx`: The batch index 2. `ids`: Tensor of embedding IDs 3. `embeddings`: Optional tensor of embedding values Lookups are added via the `append` method and can be: * Deleted with the `delete` method (up to a specific index) * Compacted with the `compact` method (merges lookups within a range) * Retrieved as unique/delta rows with the `get_delta` method ## This diffs: 1. delta_store.py includes all main logic to preserve, fetch, compact and delete 2. types.py includes required datatypes and enums 3. test_delta_store.py Includes test cases for compute, delete and compact methods Reviewed By: TroyGarden Differential Revision: D71130002 --- .../distributed/model_tracker/delta_store.py | 179 ++++ .../model_tracker/tests/test_delta_store.py | 834 ++++++++++++++++++ torchrec/distributed/model_tracker/types.py | 82 ++ 3 files changed, 1095 insertions(+) create mode 100644 torchrec/distributed/model_tracker/delta_store.py create mode 100644 torchrec/distributed/model_tracker/tests/test_delta_store.py create mode 100644 torchrec/distributed/model_tracker/types.py diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py new file mode 100644 index 000000000..315821154 --- /dev/null +++ b/torchrec/distributed/model_tracker/delta_store.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# pyre-strict +from bisect import bisect_left +from typing import Dict, List, Optional + +import torch +from torchrec.distributed.model_tracker.types import ( + DeltaRows, + EmbdUpdateMode, + IndexedLookup, +) +from torchrec.distributed.utils import none_throws + + +def _compute_unique_rows( + ids: List[torch.Tensor], + embeddings: Optional[List[torch.Tensor]], + mode: EmbdUpdateMode, +) -> DeltaRows: + r""" + To calculate unique ids and embeddings + """ + if mode == EmbdUpdateMode.NONE: + assert ( + embeddings is None + ), f"{mode=} == EmbdUpdateMode.NONE but received embeddings" + unique_ids = torch.cat(ids).unique(return_inverse=False) + return DeltaRows(ids=unique_ids, embeddings=None) + else: + assert ( + embeddings is not None + ), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings" + + cat_ids = torch.cat(ids) + cat_embeddings = torch.cat(embeddings) + + if mode == EmbdUpdateMode.LAST: + cat_ids = cat_ids.flip(dims=[0]) + cat_embeddings = cat_embeddings.flip(dims=[0]) + + # Get unique ids and inverse mapping (each element's index in unique_ids). + unique_ids, inverse = cat_ids.unique(sorted=False, return_inverse=True) + + # Create a tensor of original indices. This will be used to find first occurrences of ids. + all_indices = torch.arange(cat_ids.size(0), device=cat_ids.device) + + # Initialize tensor for first occurrence indices (filled with a high value). + first_occurrence = torch.full( + (unique_ids.size(0),), + cat_ids.size(0), + dtype=torch.int64, + device=cat_ids.device, + ) + + # Scatter indices using inverse mapping and reduce with "amin" to get first or last (if reversed) occurrence per unique id. + first_occurrence = first_occurrence.scatter_reduce( + 0, inverse, all_indices, reduce="amin" + ) + + # Use first occurrence indices to select corresponding embedding row. + unique_embedings = cat_embeddings[first_occurrence] + return DeltaRows(ids=unique_ids, embeddings=unique_embedings) + + +class DeltaStore: + """ + DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across + various batches during training, designed to be used with TorchRecs ModelDeltaTracker. + It maintains a CUDA in-memory representation of requested ids and embeddings/states, + providing a way to compact and get delta updates for each embedding table. + + The class supports different embedding update modes (NONE, FIRST, LAST) to determine + how to handle duplicate ids when compacting or retrieving embeddings. + + """ + + def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: + self.embdUpdateMode = embdUpdateMode + self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} + + def append( + self, + batch_idx: int, + table_fqn: str, + ids: torch.Tensor, + embeddings: Optional[torch.Tensor], + ) -> None: + table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, []) + table_fqn_lookup.append( + IndexedLookup(batch_idx=batch_idx, ids=ids, embeddings=embeddings) + ) + self.per_fqn_lookups[table_fqn] = table_fqn_lookup + + def delete(self, up_to_idx: Optional[int] = None) -> None: + """ + Delete all idx from the store up to `up_to_idx` + """ + if up_to_idx is None: + # If up_to_idx is None, delete all lookups + self.per_fqn_lookups = {} + else: + # lookups are sorted by idx. + up_to_idx = none_throws(up_to_idx) + for table_fqn, lookups in self.per_fqn_lookups.items(): + # remove all lookups up to up_to_idx + self.per_fqn_lookups[table_fqn] = [ + lookup for lookup in lookups if lookup.batch_idx >= up_to_idx + ] + + def compact(self, start_idx: int, end_idx: int) -> None: + r""" + Compact (ids, embeddings) in batch index range from start_idx, curr_batch_idx. + """ + assert ( + start_idx < end_idx + ), f"start_idx {start_idx} must be smaller then end_idx, but got {end_idx}" + + new_per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} + for table_fqn, lookups in self.per_fqn_lookups.items(): + indexices = [h.batch_idx for h in lookups] + index_l = bisect_left(indexices, start_idx) + index_r = bisect_left(indexices, end_idx) + lookups_to_compact = lookups[index_l:index_r] + if len(lookups_to_compact) <= 1: + new_per_fqn_lookups[table_fqn] = lookups + continue + ids = [lookup.ids for lookup in lookups_to_compact] + embeddings = ( + [none_throws(lookup.embeddings) for lookup in lookups_to_compact] + if self.embdUpdateMode != EmbdUpdateMode.NONE + else None + ) + delta_rows = _compute_unique_rows( + ids=ids, embeddings=embeddings, mode=self.embdUpdateMode + ) + new_per_fqn_lookups[table_fqn] = ( + lookups[:index_l] + + [ + IndexedLookup( + batch_idx=start_idx, + ids=delta_rows.ids, + embeddings=delta_rows.embeddings, + ) + ] + + lookups[index_r:] + ) + self.per_fqn_lookups = new_per_fqn_lookups + + def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: + r""" + Return all unique/delta ids per table from the Delta Store. + """ + + delta_per_table_fqn: Dict[str, DeltaRows] = {} + for table_fqn, lookups in self.per_fqn_lookups.items(): + compact_ids = [ + lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx + ] + compact_embeddings = ( + [ + none_throws(lookup.embeddings) + for lookup in lookups + if lookup.batch_idx >= from_idx + ] + if self.embdUpdateMode != EmbdUpdateMode.NONE + else None + ) + + delta_per_table_fqn[table_fqn] = _compute_unique_rows( + ids=compact_ids, embeddings=compact_embeddings, mode=self.embdUpdateMode + ) + return delta_per_table_fqn diff --git a/torchrec/distributed/model_tracker/tests/test_delta_store.py b/torchrec/distributed/model_tracker/tests/test_delta_store.py new file mode 100644 index 000000000..0be9fa20f --- /dev/null +++ b/torchrec/distributed/model_tracker/tests/test_delta_store.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import unittest +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch + +from parameterized import parameterized +from torchrec.distributed.model_tracker.delta_store import ( + _compute_unique_rows, + DeltaStore, +) +from torchrec.distributed.model_tracker.types import ( + DeltaRows, + EmbdUpdateMode, + IndexedLookup, +) + + +class DeltaStoreTest(unittest.TestCase): + # pyre-fixme[2]: Parameter must be annotated. + def __init__(self, methodName="runTest") -> None: + super().__init__(methodName) + + @dataclass + class AppendDeleteTestParams: + # input parameters + table_fqn_to_lookups: Dict[str, List[IndexedLookup]] + up_to_idx: Optional[int] + # expected output parameters + deleted_table_fqn_to_lookups: Dict[str, List[IndexedLookup]] + + @parameterized.expand( + [ + ( + "empty_lookups", + AppendDeleteTestParams( + table_fqn_to_lookups={}, + up_to_idx=None, + deleted_table_fqn_to_lookups={}, + ), + ), + ( + "delete_all_lookups", + AppendDeleteTestParams( + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1]), + embeddings=torch.tensor([1]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([2]), + embeddings=torch.tensor([2]), + ), + ] + }, + up_to_idx=None, + deleted_table_fqn_to_lookups={}, + ), + ), + ( + "single_table_with_idx_from_start", + AppendDeleteTestParams( + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1]), + embeddings=torch.tensor([1]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([2]), + embeddings=torch.tensor([2]), + ), + IndexedLookup( + batch_idx=3, + ids=torch.tensor([3]), + embeddings=torch.tensor([3]), + ), + IndexedLookup( + batch_idx=4, + ids=torch.tensor([4]), + embeddings=torch.tensor([4]), + ), + ] + }, + up_to_idx=3, + deleted_table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=3, + ids=torch.tensor([3]), + embeddings=torch.tensor([3]), + ), + IndexedLookup( + batch_idx=4, + ids=torch.tensor([4]), + embeddings=torch.tensor([4]), + ), + ] + }, + ), + ), + ( + "single_table_with_idx_x", + AppendDeleteTestParams( + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=8, + ids=torch.tensor([8]), + embeddings=torch.tensor([8]), + ), + IndexedLookup( + batch_idx=10, + ids=torch.tensor([10]), + embeddings=torch.tensor([10]), + ), + IndexedLookup( + batch_idx=13, + ids=torch.tensor([13]), + embeddings=torch.tensor([13]), + ), + ] + }, + up_to_idx=13, + deleted_table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=13, + ids=torch.tensor([13]), + embeddings=torch.tensor([13]), + ), + ] + }, + ), + ), + ( + "multi_table_with_idx_x", + AppendDeleteTestParams( + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=9, + ids=torch.tensor([9]), + embeddings=torch.tensor([9]), + ), + ], + "table_fqn_2": [ + IndexedLookup( + batch_idx=9, + ids=torch.tensor([9]), + embeddings=torch.tensor([9]), + ), + IndexedLookup( + batch_idx=10, + ids=torch.tensor([10]), + embeddings=torch.tensor([10]), + ), + ], + }, + up_to_idx=10, + deleted_table_fqn_to_lookups={ + "table_fqn_1": [], + "table_fqn_2": [ + IndexedLookup( + batch_idx=10, + ids=torch.tensor([10]), + embeddings=torch.tensor([10]), + ), + ], + }, + ), + ), + ] + ) + def test_append_and_delete( + self, _test_name: str, test_params: AppendDeleteTestParams + ) -> None: + delta_store = DeltaStore() + for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): + for lookup in lookup_list: + delta_store.append( + batch_idx=lookup.batch_idx, + table_fqn=table_fqn, + ids=lookup.ids, + embeddings=lookup.embeddings, + ) + # Before deletion, check that the lookups are as expected + self.assertEqual( + delta_store.per_fqn_lookups, + test_params.table_fqn_to_lookups, + ) + delta_store.delete(test_params.up_to_idx) + # After deletion, check that the lookups are as expected + self.assertEqual( + delta_store.per_fqn_lookups, + test_params.deleted_table_fqn_to_lookups, + ) + + @dataclass + class ComputeTestParams: + # input parameters + ids: List[torch.Tensor] + embeddings: Optional[List[torch.Tensor]] + embdUpdateMode: EmbdUpdateMode + # expected output parameters + expected_output: DeltaRows + expect_assert: bool + + @parameterized.expand( + [ + # test cases for EmbdUpdateMode.NONE + ( + "unique_ids", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3, 4, 5]), + torch.tensor([6, 7, 8, 9, 10]), + ], + embeddings=None, + embdUpdateMode=EmbdUpdateMode.NONE, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + embeddings=None, + ), + expect_assert=False, + ), + ), + ( + "duplicate_ids", + ComputeTestParams( + ids=[ + torch.tensor([4, 1, 3, 6, 5, 2]), + torch.tensor([2, 10, 8, 4, 9, 7]), + ], + embeddings=None, + embdUpdateMode=EmbdUpdateMode.NONE, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + embeddings=None, + ), + expect_assert=False, + ), + ), + # test case for EmbdUpdateMode.NONE with embeddings (should assert) + ( + "none_mode_with_embeddings", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + ], + embeddings=[ + torch.tensor([[1.0], [2.0], [3.0]]), + torch.tensor([[4.0], [5.0], [6.0]]), + ], + embdUpdateMode=EmbdUpdateMode.NONE, + expected_output=DeltaRows( + ids=torch.tensor([]), + embeddings=None, + ), + expect_assert=True, + ), + ), + # test cases for EmbdUpdateMode.FIRST + ( + "first_mode_without_embeddings", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + ], + embeddings=None, + embdUpdateMode=EmbdUpdateMode.FIRST, + expected_output=DeltaRows( + ids=torch.tensor([]), + embeddings=None, + ), + expect_assert=True, + ), + ), + ( + "first_mode_unique_ids", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + ], + embeddings=[ + torch.tensor([[1.0], [2.0], [3.0]]), + torch.tensor([[4.0], [5.0], [6.0]]), + ], + embdUpdateMode=EmbdUpdateMode.FIRST, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6]), + embeddings=torch.tensor( + [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + ), + ), + expect_assert=False, + ), + ), + ( + "first_mode_duplicate_ids", + ComputeTestParams( + ids=[ + torch.tensor([4, 1, 3, 6, 5, 2]), + torch.tensor([2, 10, 8, 4, 9, 7]), + ], + embeddings=[ + torch.tensor([[40.0], [10.0], [30.0], [60.0], [50.0], [20.0]]), + torch.tensor([[25.0], [100.0], [80.0], [45.0], [90.0], [70.0]]), + ], + embdUpdateMode=EmbdUpdateMode.FIRST, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + # First occurrence of each ID is kept + embeddings=torch.tensor( + [ + [10.0], + [20.0], + [30.0], + [40.0], + [50.0], + [60.0], + [70.0], + [80.0], + [90.0], + [100.0], + ] + ), + ), + expect_assert=False, + ), + ), + # test cases for EmbdUpdateMode.LAST + ( + "last_mode_without_embeddings", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + ], + embeddings=None, + embdUpdateMode=EmbdUpdateMode.LAST, + expected_output=DeltaRows( + ids=torch.tensor([]), + embeddings=None, + ), + expect_assert=True, + ), + ), + ( + "last_mode_unique_ids", + ComputeTestParams( + ids=[ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + ], + embeddings=[ + torch.tensor([[1.0], [2.0], [3.0]]), + torch.tensor([[4.0], [5.0], [6.0]]), + ], + embdUpdateMode=EmbdUpdateMode.LAST, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6]), + embeddings=torch.tensor( + [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]] + ), + ), + expect_assert=False, + ), + ), + ( + "last_mode_duplicate_ids", + ComputeTestParams( + ids=[ + torch.tensor([4, 1, 3, 6, 5, 2]), + torch.tensor([2, 10, 8, 4, 9, 7]), + ], + embeddings=[ + torch.tensor([[40.0], [10.0], [30.0], [60.0], [50.0], [20.0]]), + torch.tensor([[25.0], [100.0], [80.0], [45.0], [90.0], [70.0]]), + ], + embdUpdateMode=EmbdUpdateMode.LAST, + expected_output=DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + # Last occurrence of each ID is kept + embeddings=torch.tensor( + [ + [10.0], + [25.0], + [30.0], + [45.0], + [50.0], + [60.0], + [70.0], + [80.0], + [90.0], + [100.0], + ] + ), + ), + expect_assert=False, + ), + ), + ] + ) + def test_compute_unique_rows( + self, _test_name: str, test_params: ComputeTestParams + ) -> None: + if test_params.expect_assert: + # If we expect an assertion error, check that it's raised + with self.assertRaises(AssertionError): + _compute_unique_rows( + test_params.ids, test_params.embeddings, test_params.embdUpdateMode + ) + else: + # Otherwise, proceed with the normal test + result = _compute_unique_rows( + test_params.ids, test_params.embeddings, test_params.embdUpdateMode + ) + + self.assertTrue(torch.equal(result.ids, test_params.expected_output.ids)) + self.assertTrue( + torch.equal( + ( + result.embeddings + if result.embeddings is not None + else torch.empty(0) + ), + ( + test_params.expected_output.embeddings + if test_params.expected_output.embeddings is not None + else torch.empty(0) + ), + ) + ) + + @dataclass + class CompactTestParams: + # input parameters + embdUpdateMode: EmbdUpdateMode + table_fqn_to_lookups: Dict[str, List[IndexedLookup]] + start_idx: int + end_idx: int + # expected output parameters + expected_delta: Dict[str, DeltaRows] + expect_assert: bool = False + + @parameterized.expand( + [ + # Test case for compaction with EmbdUpdateMode.NONE + ( + "empty_lookups", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.NONE, + table_fqn_to_lookups={}, + start_idx=1, + end_idx=5, + expected_delta={}, + ), + ), + ( + "single_lookup_no_compaction", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.NONE, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=3, + ids=torch.tensor([1, 2, 3]), + embeddings=None, + ), + ] + }, + start_idx=1, + end_idx=5, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3]), + embeddings=None, + ), + }, + ), + ), + ( + "multi_lookup_all_unique", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.NONE, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=None, + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([4, 5, 6]), + embeddings=None, + ), + IndexedLookup( + batch_idx=3, + ids=torch.tensor([7, 8, 9]), + embeddings=None, + ), + ] + }, + start_idx=1, + end_idx=3, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), + embeddings=None, + ), + }, + ), + ), + ( + "multi_lookup_with_duplicates", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.NONE, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=None, + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([3, 4, 5]), + embeddings=None, + ), + IndexedLookup( + batch_idx=3, + ids=torch.tensor([5, 6, 7]), + embeddings=None, + ), + IndexedLookup( + batch_idx=4, + ids=torch.tensor([7, 8, 9]), + embeddings=None, + ), + ] + }, + start_idx=1, + end_idx=4, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), + embeddings=None, + ), + }, + ), + ), + # Test case for compaction with EmbdUpdateMode.FIRST + ( + "multi_lookup_with_duplicates_first_mode", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.FIRST, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=torch.tensor([[10.0], [20.0], [30.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([3, 4, 5]), + embeddings=torch.tensor([[35.0], [40.0], [50.0]]), + ), + IndexedLookup( + batch_idx=3, + ids=torch.tensor([5, 6, 7]), + embeddings=torch.tensor([[55.0], [60.0], [70.0]]), + ), + IndexedLookup( + batch_idx=4, + ids=torch.tensor([7, 8, 9]), + embeddings=torch.tensor([[75.0], [80.0], [90.0]]), + ), + ] + }, + start_idx=1, + end_idx=4, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), + embeddings=torch.tensor( + [ + [10.0], + [20.0], + [30.0], + [40.0], + [50.0], + [60.0], + [70.0], + [80.0], + [90.0], + ] + ), + ), + }, + ), + ), + ( + "multiple_tables_first_mode", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.FIRST, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=torch.tensor([[10.0], [20.0], [30.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([3, 4, 5]), + embeddings=torch.tensor([[35.0], [40.0], [50.0]]), + ), + ], + "table_fqn_2": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([10, 20, 30]), + embeddings=torch.tensor([[100.0], [200.0], [300.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([30, 40, 50]), + embeddings=torch.tensor([[350.0], [400.0], [500.0]]), + ), + ], + }, + start_idx=1, + end_idx=3, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5]), + embeddings=torch.tensor( + [[10.0], [20.0], [30.0], [40.0], [50.0]] + ), + ), + "table_fqn_2": DeltaRows( + ids=torch.tensor([10, 20, 30, 40, 50]), + embeddings=torch.tensor( + [[100.0], [200.0], [300.0], [400.0], [500.0]] + ), + ), + }, + ), + ), + # Test case for compaction with EmbdUpdateMode.LAST + ( + "multi_lookup_with_duplicates_last_mode", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.LAST, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=torch.tensor([[10.0], [20.0], [30.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([3, 4, 5]), + embeddings=torch.tensor([[35.0], [40.0], [50.0]]), + ), + IndexedLookup( + batch_idx=3, + ids=torch.tensor([5, 6, 7]), + embeddings=torch.tensor([[55.0], [60.0], [70.0]]), + ), + IndexedLookup( + batch_idx=4, + ids=torch.tensor([7, 8, 9]), + embeddings=torch.tensor([[75.0], [80.0], [90.0]]), + ), + ] + }, + start_idx=1, + end_idx=4, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), + embeddings=torch.tensor( + [ + [10.0], + [20.0], + [35.0], + [40.0], + [55.0], + [60.0], + [75.0], + [80.0], + [90.0], + ] + ), + ), + }, + ), + ), + ( + "multiple_tables_last_mode", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.LAST, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=torch.tensor([[10.0], [20.0], [30.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([3, 4, 5]), + embeddings=torch.tensor([[35.0], [40.0], [50.0]]), + ), + ], + "table_fqn_2": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([10, 20, 30]), + embeddings=torch.tensor([[100.0], [200.0], [300.0]]), + ), + IndexedLookup( + batch_idx=2, + ids=torch.tensor([30, 40, 50]), + embeddings=torch.tensor([[350.0], [400.0], [500.0]]), + ), + ], + }, + start_idx=1, + end_idx=3, + expected_delta={ + "table_fqn_1": DeltaRows( + ids=torch.tensor([1, 2, 3, 4, 5]), + embeddings=torch.tensor( + [[10.0], [20.0], [35.0], [40.0], [50.0]] + ), + ), + "table_fqn_2": DeltaRows( + ids=torch.tensor([10, 20, 30, 40, 50]), + embeddings=torch.tensor( + [[100.0], [200.0], [350.0], [400.0], [500.0]] + ), + ), + }, + ), + ), + # Test case for invalid start_idx and end_idx + ( + "invalid_indices", + CompactTestParams( + embdUpdateMode=EmbdUpdateMode.NONE, + table_fqn_to_lookups={ + "table_fqn_1": [ + IndexedLookup( + batch_idx=1, + ids=torch.tensor([1, 2, 3]), + embeddings=None, + ), + ] + }, + start_idx=5, + end_idx=3, + expected_delta={}, + expect_assert=True, + ), + ), + ] + ) + def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None: + """ + Test the compact method of DeltaStore. + """ + # Create a DeltaStore with the specified embdUpdateMode + delta_store = DeltaStore(embdUpdateMode=test_params.embdUpdateMode) + + # Populate the DeltaStore with the test lookups + for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items(): + for lookup in lookup_list: + delta_store.append( + batch_idx=lookup.batch_idx, + table_fqn=table_fqn, + ids=lookup.ids, + embeddings=lookup.embeddings, + ) + if test_params.expect_assert: + # If we expect an assertion error, check that it's raised + with self.assertRaises(AssertionError): + delta_store.compact( + start_idx=test_params.start_idx, end_idx=test_params.end_idx + ) + else: + # Call the compact method + delta_store.compact( + start_idx=test_params.start_idx, end_idx=test_params.end_idx + ) + # Verify the result using get_delta method + delta_result = delta_store.get_delta() + + # compare all fqns in the result + for table_fqn, delta_rows in test_params.expected_delta.items(): + # Comparing ids + self.assertTrue(delta_result[table_fqn].ids.allclose(delta_rows.ids)) + # Comparing embeddings + if ( + delta_rows.embeddings is not None + and delta_result[table_fqn].embeddings is not None + ): + self.assertTrue( + # pyre-ignore + delta_result[table_fqn].embeddings.allclose( + delta_rows.embeddings + ) + ) diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py new file mode 100644 index 000000000..c92b93732 --- /dev/null +++ b/torchrec/distributed/model_tracker/types.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import torch + + +@dataclass +class IndexedLookup: + r""" + Data class for storing per batch lookedup ids and embeddings. + """ + + batch_idx: int + ids: torch.Tensor + embeddings: Optional[torch.Tensor] + + +@dataclass +class DeltaRows: + r""" + Data class as an interface for returning and storing compacted ids and embeddings. + compact(List[IndexedLookup]) -> DeltaRows + """ + + ids: torch.Tensor + embeddings: Optional[torch.Tensor] + + +class TrackingMode(Enum): + r""" + Tracking mode for ``ModelDeltaTracker``. + + Enums: + ID_ONLY: Tracks row IDs only, providing a lightweight option for monitoring. + EMBEDDING: Tracks both row IDs and their corresponding embedding values, + enabling precise top-k result calculations. However, this option comes with increased memory usage. + """ + + ID_ONLY = "id_only" + EMBEDDING = "embedding" + + +class EmbdUpdateMode(Enum): + r""" + To identify which embedding value to store while tracking. + + Enums: + NONE: Used for id only mode when we aren't tracking the embeddings. + FIRST: Stores the earlier embedding value for each id. Useful for checkpoint/snapshot. + LAST: Stores the latest embedding value for each id. Used for some opmtimizer state modes. + """ + + NONE = "none" + FIRST = "first" + LAST = "last" + + +@dataclass +class DeltaTrackerConfig: + r""" + Configuration for ``ModelDeltaTracker``. + + Args: + tracking_mode (TrackingMode): tracking mode for the delta tracker. + consumers (Optional[List[str]]): list of consumers for the delta tracker. + delete_on_read (bool): whether to delete the compacted data after get_delta method is called. + + """ + + tracking_mode: TrackingMode + consumers: Optional[List[str]] = None + delete_on_read: bool = True