diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 1c114a2d3..102ed2060 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -855,6 +855,12 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # no
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
+
+        Args:
+            skip_registering (bool): If True, skips registering the state_dict
+                hooks. This is useful for dynamic sharding, where the hooks do
+                not need to be re-registered when resharding. Defaults to False.
+
         """
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
         for table_name in self._table_names:
diff --git a/torchrec/distributed/test_utils/test_input.py b/torchrec/distributed/test_utils/test_input.py
index 0ac44ae80..03609317b 100644
--- a/torchrec/distributed/test_utils/test_input.py
+++ b/torchrec/distributed/test_utils/test_input.py
@@ -53,6 +53,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
         )
 
     def record_stream(self, stream: torch.Stream) -> None:
+        """
+        `record_stream` must be called explicitly for non-PyTorch-native objects (KJT).
+        """
         self.float_features.record_stream(stream)
         if isinstance(self.idlist_features, KeyedJaggedTensor):
             self.idlist_features.record_stream(stream)
@@ -204,7 +207,7 @@ def generate_local_batches(
         all_zeros: bool = False,
     ) -> List["ModelInput"]:
         """
-        Returns multi-rank batches of world_size
+        Returns `world_size` multi-rank batches (one `ModelInput` per rank)
         """
         return [
             cls.generate(
@@ -255,7 +258,7 @@ def generate(
         all_zeros: bool = False,
     ) -> "ModelInput":
         """
-        Returns a single batch
+        Returns a single batch as a `ModelInput`
         """
         float_features = (
             torch.zeros((batch_size, num_float_features), device=device)
@@ -263,7 +266,7 @@ def generate(
             else torch.rand((batch_size, num_float_features), device=device)
         )
         idlist_features = (
-            ModelInput._create_standard_kjt(
+            ModelInput.create_standard_kjt(
                 batch_size=batch_size,
                 tables=tables,
                 pooling_avg=pooling_avg,
@@ -281,7 +284,7 @@ def generate(
             else None
         )
         idscore_features = (
-            ModelInput._create_standard_kjt(
+            ModelInput.create_standard_kjt(
                 batch_size=batch_size,
                 tables=weighted_tables,
                 pooling_avg=pooling_avg,
@@ -324,6 +327,13 @@ def _create_features_lengths_indices(
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
     ) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
+        """
+        Creates keys, lengths, and indices for a KeyedJaggedTensor from embedding table configs.
+
+        Returns:
+            Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
+                Feature names, per-feature lengths, and per-feature indices.
+        """
         pooling_factor_per_feature: List[int] = []
         num_embeddings_per_feature: List[int] = []
         max_length_per_feature: List[Optional[int]] = []
@@ -395,6 +405,14 @@ def _assemble_kjt(
         use_offsets: bool = False,
         offsets_dtype: torch.dtype = torch.int64,
     ) -> KeyedJaggedTensor:
+        """
+        Assembles a KeyedJaggedTensor (KJT) from the provided per-feature
+        lengths and indices.
+
+        This method is used to generate the corresponding local-batch and
+        global-batch KJTs; it concatenates the lengths and indices of each
+        feature to form a complete KJT.
+        """
         lengths = torch.cat(lengths_per_feature)
         indices = torch.cat(indices_per_feature)
         offsets = None
@@ -407,7 +425,7 @@ def _assemble_kjt(
         return KeyedJaggedTensor(features, indices, weights, lengths, offsets)
 
     @staticmethod
-    def _create_standard_kjt(
+    def create_standard_kjt(
         batch_size: int,
         tables: Union[
             List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig]
@@ -464,6 +482,10 @@ def _create_batched_standard_kjts(
         lengths_dtype: torch.dtype = torch.int64,
         all_zeros: bool = False,
     ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]:
+        """
+        Generates a global KJT and the corresponding per-rank KJTs. Both hold
+        the same data, so they can be used for result comparison.
+        """
         data_per_rank = [
             ModelInput._create_features_lengths_indices(
                 batch_size,
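Note: with `_create_standard_kjt` promoted to the public `create_standard_kjt`, callers outside `ModelInput` can generate per-rank KJT inputs directly. A minimal usage sketch (not part of the patch; the table config values are illustrative, only `batch_size` and `tables` are required per the signature above):

    from typing import List

    from torchrec.distributed.test_utils.test_input import ModelInput
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Illustrative table config; any List[EmbeddingBagConfig] works here.
    tables: List[EmbeddingBagConfig] = [
        EmbeddingBagConfig(
            num_embeddings=4,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]

    world_size = 2
    # One standard KJT per rank, mirroring the resharding test change below.
    kjt_input_per_rank: List[KeyedJaggedTensor] = [
        ModelInput.create_standard_kjt(batch_size=2, tables=tables)
        for _ in range(world_size)
    ]
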
+ """ + lengths = torch.cat(lengths_per_feature) indices = torch.cat(indices_per_feature) offsets = None @@ -407,7 +425,7 @@ def _assemble_kjt( return KeyedJaggedTensor(features, indices, weights, lengths, offsets) @staticmethod - def _create_standard_kjt( + def create_standard_kjt( batch_size: int, tables: Union[ List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig] @@ -464,6 +482,10 @@ def _create_batched_standard_kjts( lengths_dtype: torch.dtype = torch.int64, all_zeros: bool = False, ) -> Tuple[KeyedJaggedTensor, List[KeyedJaggedTensor]]: + """ + generate a global KJT and corresponding per-rank KJTs, the data are the same + so that they can be used for result comparison. + """ data_per_rank = [ ModelInput._create_features_lengths_indices( batch_size, diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index ccc46cc94..021ecc1a1 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -35,6 +35,7 @@ MultiProcessContext, MultiProcessTestBase, ) +from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.distributed.test_utils.test_sharding import copy_state_dict from torchrec.distributed.types import ( @@ -58,46 +59,6 @@ def feature_name(i: int) -> str: return "feature_" + str(i) -def generate_input_by_world_size( - world_size: int, - num_tables: int, - num_embeddings: int = 4, - max_mul: int = 3, -) -> List[KeyedJaggedTensor]: - # TODO merge with new ModelInput generator in TestUtils - kjt_input_per_rank = [] - mul = random.randint(1, max_mul) - total_size = num_tables * mul - - for _ in range(world_size): - feature_names = [feature_name(i) for i in range(num_tables)] - lengths = [] - values = [] - counting_l = 0 - for i in range(total_size): - if i == total_size - 1: - lengths.append(total_size - counting_l) - break - next_l = random.randint(0, total_size - counting_l) - values.extend( - [random.randint(0, num_embeddings - 1) for _ in range(next_l)] - ) - lengths.append(next_l) - counting_l += next_l - - # for length in lengths: - - kjt_input_per_rank.append( - KeyedJaggedTensor.from_lengths_sync( - keys=feature_names, - values=torch.LongTensor(values), - lengths=torch.LongTensor(lengths), - ) - ) - - return kjt_input_per_rank - - def generate_embedding_bag_config( data_type: DataType, num_tables: int = 3, @@ -372,9 +333,13 @@ def _run_ebc_resharding_test( ): return - kjt_input_per_rank = generate_input_by_world_size( - world_size, num_tables, num_embeddings - ) + kjt_input_per_rank = [ + ModelInput.create_standard_kjt( + batch_size=2, + tables=embedding_bag_config, + ) + for _ in range(world_size) + ] # initial_state_dict filled with deterministic dummy values initial_state_dict = create_test_initial_state_dict( @@ -418,8 +383,8 @@ def test_dynamic_sharding_ebc_tw( old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)] new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)] - if new_ranks == old_ranks: - return + while new_ranks == old_ranks: + new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)] per_param_sharding = {} new_per_param_sharding = {}