From af119265b631006341ad3945a00b6abc9cac2305 Mon Sep 17 00:00:00 2001 From: Ali Afzal Date: Fri, 28 Mar 2025 07:05:58 -0700 Subject: [PATCH] Multi forward MCH eviction fix (#2836) Summary: ## Issue: Direct tensor modification during training with multiple forward passes breaks PyTorch's autograd graph, causing "one of the variables needed for gradient computation has been modified by an inplace operation" runtime error. ## Solution: Use in-place updates with .data accessor to safely reinitialize evicted embeddings without invalidating gradient computation. Reviewed By: dstaay-fb Differential Revision: D71491003 --- torchrec/distributed/mc_embedding_modules.py | 23 +- .../distributed/tests/test_mc_embedding.py | 265 +++++++++++++++ .../distributed/tests/test_mc_embeddingbag.py | 304 ++++++++++++------ torchrec/modules/mc_embedding_modules.py | 24 +- 4 files changed, 502 insertions(+), 114 deletions(-) diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index e66cb674b..b817f020a 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -129,6 +129,9 @@ def __init__( ) ) self._return_remapped_features: bool = module._return_remapped_features + self._allow_in_place_embed_weight_update: bool = ( + module._allow_in_place_embed_weight_update + ) # pyre-ignore self._table_to_tbe_and_index = {} @@ -202,12 +205,22 @@ def _evict(self, evictions_per_table: Dict[str, Optional[torch.Tensor]]) -> None init_fn = self._embedding_module._table_name_to_config[ table ].init_fn - # Set evicted indices to original init_fn instead of all zeros - # pyre-ignore [29] - table_weight_param[evictions_indices_for_table] = init_fn( - table_weight_param[evictions_indices_for_table] - ) + if self._allow_in_place_embed_weight_update: + # In-place update with .data to bypass PyTorch's autograd tracking. + # This is required for model training with multiple forward passes where the autograd graph + # is already created. Direct tensor modification would trigger PyTorch's in-place operation + # checks and invalidate gradients, while .data allows safe reinitialization of evicted + # embeddings without affecting the computational graph. + # pyre-ignore [29] + table_weight_param.data[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) + else: + # pyre-ignore [29] + table_weight_param[evictions_indices_for_table] = init_fn( + table_weight_param[evictions_indices_for_table] + ) def compute( self, diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py index 60de369d1..64c3ca14e 100644 --- a/torchrec/distributed/tests/test_mc_embedding.py +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -59,6 +59,7 @@ def __init__( device: torch.device, return_remapped: bool = False, input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__() self._return_remapped = return_remapped @@ -91,6 +92,7 @@ def __init__( embedding_configs=tables, ), return_remapped_features=self._return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, ) ) @@ -242,6 +244,106 @@ def _test_sharding_and_remapping( # noqa C901 # TODO: validate embedding rows, and eviction +def _test_in_place_embd_weight_update( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + initial_state_per_rank: List[Dict[str, torch.Tensor]], + final_state_per_rank: List[Dict[str, torch.Tensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + input_hash_size: int = 4000, + allow_in_place_embed_weight_update: bool = True, +) -> None: + + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + input_hash_size=input_hash_size, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ec._embedding_collection.embeddings["table_0"].weight, + sparse_arch._mc_ec._embedding_collection.embeddings["table_1"].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ec, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ec": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + initial_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in initial_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in initial_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, initial_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {initial_state_per_rank[rank][postfix]}" + + sharded_sparse_arch.load_state_dict(initial_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + final_state_dict = sharded_sparse_arch.state_dict() + for key, sharded_tensor in final_state_dict.items(): + postfix = ".".join(key.split(".")[-2:]) + if postfix in final_state_per_rank[ctx.rank]: + tensor = sharded_tensor.local_shards()[0].tensor.cpu() + assert torch.equal( + tensor, final_state_per_rank[ctx.rank][postfix] + ), f"initial state {key} on {ctx.rank} does not match, got {tensor}, expect {final_state_per_rank[rank][postfix]}" + + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + def _test_sharding_and_resharding( # noqa C901 tables: List[EmbeddingConfig], rank: int, @@ -1016,3 +1118,166 @@ def test_sharding_zch_mc_ec_dedup_input_error(self, backend: str) -> None: ), except AssertionError as e: self.assertTrue("0 != 1" in str(e)) + + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: + + WORLD_SIZE = 2 + + embedding_config = [ + EmbeddingConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), + ] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 2, + 2, + 2, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + # TODO: cleanup sorting so more dedugable/logical initial fill + + kjt_out_per_iter_per_rank.append( + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ] + ) + + initial_state_per_rank = [ + { + "table_0._mch_remapped_ids_mapping": torch.arange(8, dtype=torch.int64), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_remapped_ids_mapping": torch.arange( + start=8, end=16, dtype=torch.int64 + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + start=16, end=32, dtype=torch.int64 + ), + }, + ] + max_int = torch.iinfo(torch.int64).max + + final_state_per_rank = [ + { + "table_0._mch_sorted_raw_ids": torch.LongTensor( + [1000, 1001, 1002, 1004] + [max_int] * 4 + ), + "table_1._mch_sorted_raw_ids": torch.LongTensor([max_int] * 16), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [3, 4, 5, 6, 0, 1, 2, 7] + ), + "table_1._mch_remapped_ids_mapping": torch.arange( + 16, dtype=torch.int64 + ), + }, + { + "table_0._mch_sorted_raw_ids": torch.LongTensor([2000] + [max_int] * 7), + "table_1._mch_sorted_raw_ids": torch.LongTensor( + [2000, 2001, 2002, 2004] + [max_int] * 12 + ), + "table_0._mch_remapped_ids_mapping": torch.LongTensor( + [14, 8, 9, 10, 11, 12, 13, 15] + ), + "table_1._mch_remapped_ids_mapping": torch.LongTensor( + [27, 29, 28, 30, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 31] + ), + }, + ] + + self._run_multi_process_test( + callable=_test_in_place_embd_weight_update, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + initial_state_per_rank=initial_state_per_rank, + final_state_per_rank=final_state_per_rank, + sharder=ManagedCollisionEmbeddingCollectionSharder(), + backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) diff --git a/torchrec/distributed/tests/test_mc_embeddingbag.py b/torchrec/distributed/tests/test_mc_embeddingbag.py index e891e8841..a24caf2cc 100644 --- a/torchrec/distributed/tests/test_mc_embeddingbag.py +++ b/torchrec/distributed/tests/test_mc_embeddingbag.py @@ -9,7 +9,7 @@ import copy import unittest -from typing import Dict, List, Optional, Tuple +from typing import Dict, Final, List, Optional, Tuple import torch import torch.nn as nn @@ -43,12 +43,103 @@ from torchrec.test_utils import skip_if_asan_class +# Global constants for testing ShardedManagedCollisionEmbeddingBagCollection + +WORLD_SIZE = 2 + +# Input KeyedJaggedTensors for each rank in distributed tests +embedding_bag_config: Final[List[EmbeddingBagConfig]] = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=8, + num_embeddings=16, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=8, + num_embeddings=32, + ), +] + +# Expected remapped outputs per iteration per rank for validation +kjt_input_per_rank: Final[List[KeyedJaggedTensor]] = [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1", "feature_2"], + values=torch.LongTensor( + [ + 1000, + 1002, + 1004, + 2000, + 2002, + 2004, + 1, + 1, + 1, + ], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), + weights=None, + ), +] + +kjt_out_per_iter_per_rank: Final[List[List[KeyedJaggedTensor]]] = [ + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 15, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [7, 7, 7, 31, 31, 31], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ], + [ + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 14, 4, 27, 29, 28], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor( + [3, 5, 6, 27, 28, 30], + ), + lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), + weights=None, + ), + ], +] + + class SparseArch(nn.Module): def __init__( self, tables: List[EmbeddingBagConfig], device: torch.device, return_remapped: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__() self._return_remapped = return_remapped @@ -81,6 +172,7 @@ def __init__( embedding_configs=tables, ), return_remapped_features=self._return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, ) ) @@ -268,6 +360,87 @@ def _test_sharding_and_remapping( # noqa C901 # TODO: validate embedding rows, and eviction +def _test_in_place_embd_weight_update( # noqa C901 + output_keys: List[str], + tables: List[EmbeddingBagConfig], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]], + sharder: ModuleSharder[nn.Module], + backend: str, + local_size: Optional[int] = None, + allow_in_place_embed_weight_update: bool = True, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input = kjt_input_per_rank[rank].to(ctx.device) + kjt_out_per_iter = [ + kjt[rank].to(ctx.device) for kjt in kjt_out_per_iter_per_rank + ] + return_remapped: bool = True + sparse_arch = SparseArch( + tables, + torch.device("meta"), + return_remapped=return_remapped, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, + ) + apply_optimizer_in_backward( + RowWiseAdagrad, + [ + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_0" + ].weight, + sparse_arch._mc_ebc._embedding_bag_collection.embedding_bags[ + "table_1" + ].weight, + ], + {"lr": 0.01}, + ) + module_sharding_plan = construct_module_sharding_plan( + sparse_arch._mc_ebc, + per_param_sharding={"table_0": row_wise(), "table_1": row_wise()}, + local_size=local_size, + world_size=world_size, + device_type="cuda" if torch.cuda.is_available() else "cpu", + sharder=sharder, + ) + + sharded_sparse_arch = _shard_modules( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"_mc_ebc": module_sharding_plan}), + # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got + # `Optional[ProcessGroup]`. + env=ShardingEnv.from_process_group(ctx.pg), + sharders=[sharder], + device=ctx.device, + ) + + test_state_dict = sharded_sparse_arch.state_dict() + sharded_sparse_arch.load_state_dict(test_state_dict) + + # sharded model + # each rank gets a subbatch + loss1, remapped_ids1 = sharded_sparse_arch(kjt_input) + loss2, remapped_ids2 = sharded_sparse_arch(kjt_input) + if not allow_in_place_embed_weight_update: + # Without in-place overwrite the backward pass will fail due to tensor version mismatch + with unittest.TestCase().assertRaisesRegex( + RuntimeError, + "one of the variables needed for gradient computation has been modified by an inplace operation", + ): + loss1.backward() + else: + loss1.backward() + loss2.backward() + remapped_ids = [remapped_ids1, remapped_ids2] + for key in output_keys: + for i, kjt_out in enumerate(kjt_out_per_iter): + assert torch.equal( + remapped_ids[i][key].values(), + kjt_out[key].values(), + ), f"feature {key} on {ctx.rank} iteration {i} does not match, got {remapped_ids[i][key].values()}, expect {kjt_out[key].values()}" + + @skip_if_asan_class class ShardedMCEmbeddingBagCollectionParallelTest(MultiProcessTestBase): @unittest.skipIf( @@ -311,22 +484,6 @@ def test_uneven_sharding(self, backend: str) -> None: @given(backend=st.sampled_from(["nccl"])) @settings(deadline=None) def test_even_sharding(self, backend: str) -> None: - WORLD_SIZE = 2 - - embedding_bag_config = [ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0"], - embedding_dim=8, - num_embeddings=16, - ), - EmbeddingBagConfig( - name="table_1", - feature_names=["feature_1"], - embedding_dim=8, - num_embeddings=32, - ), - ] self._run_multi_process_test( callable=_test_sharding, @@ -344,99 +501,33 @@ def test_even_sharding(self, backend: str) -> None: @given(backend=st.sampled_from(["nccl"])) @settings(deadline=None) def test_sharding_zch_mc_ebc(self, backend: str) -> None: - - WORLD_SIZE = 2 - - embedding_bag_config = [ - EmbeddingBagConfig( - name="table_0", - feature_names=["feature_0"], - embedding_dim=8, - num_embeddings=16, - ), - EmbeddingBagConfig( - name="table_1", - feature_names=["feature_1"], - embedding_dim=8, - num_embeddings=32, - ), - ] - - kjt_input_per_rank = [ # noqa - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1", "feature_2"], - values=torch.LongTensor( - [1000, 2000, 1001, 2000, 2001, 2002, 1, 1, 1], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1", "feature_2"], - values=torch.LongTensor( - [ - 1000, - 1002, - 1004, - 2000, - 2002, - 2004, - 1, - 1, - 1, - ], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), - weights=None, - ), - ] - - kjt_out_per_iter_per_rank: List[List[KeyedJaggedTensor]] = [] - kjt_out_per_iter_per_rank.append( - [ - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [7, 15, 7, 31, 31, 31], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [7, 7, 7, 31, 31, 31], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - ] + self._run_multi_process_test( + callable=_test_sharding_and_remapping, + output_keys=["feature_0", "feature_1"], + world_size=WORLD_SIZE, + tables=embedding_bag_config, + kjt_input_per_rank=kjt_input_per_rank, + kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, + sharder=ManagedCollisionEmbeddingBagCollectionSharder(), + backend=backend, ) - # TODO: cleanup sorting so more dedugable/logical initial fill - kjt_out_per_iter_per_rank.append( - [ - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [3, 14, 4, 27, 29, 28], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - KeyedJaggedTensor.from_lengths_sync( - keys=["feature_0", "feature_1"], - values=torch.LongTensor( - [3, 5, 6, 27, 28, 30], - ), - lengths=torch.LongTensor([1, 1, 1, 1, 1, 1]), - weights=None, - ), - ] - ) + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-ignore + @given( + backend=st.sampled_from(["nccl"]), + allow_in_place_embed_weight_update=st.booleans(), + ) + @settings(deadline=None) + def test_in_place_embd_weight_update( + self, backend: str, allow_in_place_embed_weight_update: bool + ) -> None: self._run_multi_process_test( - callable=_test_sharding_and_remapping, + callable=_test_in_place_embd_weight_update, output_keys=["feature_0", "feature_1"], world_size=WORLD_SIZE, tables=embedding_bag_config, @@ -444,4 +535,5 @@ def test_sharding_zch_mc_ebc(self, backend: str) -> None: kjt_out_per_iter_per_rank=kjt_out_per_iter_per_rank, sharder=ManagedCollisionEmbeddingBagCollectionSharder(), backend=backend, + allow_in_place_embed_weight_update=allow_in_place_embed_weight_update, ) diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py index 834ad667c..6e7850dba 100644 --- a/torchrec/modules/mc_embedding_modules.py +++ b/torchrec/modules/mc_embedding_modules.py @@ -39,6 +39,9 @@ class BaseManagedCollisionEmbeddingCollection(nn.Module): managed_collision_modules: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. """ @@ -47,10 +50,12 @@ def __init__( embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection], managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__() self._managed_collision_collection = managed_collision_collection self._return_remapped_features = return_remapped_features + self._allow_in_place_embed_weight_update = allow_in_place_embed_weight_update self._embedding_module: Union[EmbeddingBagCollection, EmbeddingCollection] = ( embedding_module ) @@ -97,10 +102,13 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio For details of input and output types, see EmbeddingCollection Args: - embedding_module: EmbeddingCollection to lookup embeddings - managed_collision_modules: Dict of managed collision modules + embedding_collection: EmbeddingCollection to lookup embeddings + managed_collision_collection: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): enable in place update of embedding + weights on evict. This flag when enabled will allow update embedding + weights without modifying of autograd graph. """ @@ -109,9 +117,13 @@ def __init__( embedding_collection: EmbeddingCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__( - embedding_collection, managed_collision_collection, return_remapped_features + embedding_collection, + managed_collision_collection, + return_remapped_features, + allow_in_place_embed_weight_update, ) # For consistency with embedding bag collection @@ -132,6 +144,10 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec managed_collision_modules: Dict of managed collision modules return_remapped_features (bool): whether to return remapped input features in addition to embeddings + allow_in_place_embed_weight_update(bool): Enables in-place update of embedding + weights on eviction. When enabled, this flag allows updates to embedding + weights without modifying the autograd graph. + """ @@ -140,11 +156,13 @@ def __init__( embedding_bag_collection: EmbeddingBagCollection, managed_collision_collection: ManagedCollisionCollection, return_remapped_features: bool = False, + allow_in_place_embed_weight_update: bool = False, ) -> None: super().__init__( embedding_bag_collection, managed_collision_collection, return_remapped_features, + allow_in_place_embed_weight_update, ) # For backwards compat, as references existed in tests