From cc3029ffc6a321c77d7118c3e5a1dbb2378f55be Mon Sep 17 00:00:00 2001 From: Felicity Liao <11263993+aporialiao@users.noreply.github.com> Date: Wed, 23 Apr 2025 13:38:00 -0700 Subject: [PATCH] Enable Optimizer Storing & Fix incomplete updates to Sharded EBC attributes in resharding (#2911) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2911 Previously the dynamic sharding unit test was incomplete in truly verifying that a resharded EBC has all the attributes updated correctly. I ran into these issues when trying to enable optimizer state storing and DMP interface in D73049934 Main changes: 1. Add in dynamic sharding unit test's `are_sharded_ebc_modules_identical` the private attributes for ShardedEmbeddingCollection. This method will only compare primitive types or primitive reference types and tensors 1. This helped identify the gaps in current DS implementation - namely `module_sharding_plan`, `_embedding_dims`, `_uncombined_embedding_names`, `_uncombined_embedding_dims` not being updated correctly to reflect the new shard placements & order 2. Add in updates to `module_sharding_plan`, `_embedding_dims`, `_uncombined_embedding_names`, `_uncombined_embedding_dims` in reshard API for Sharded EBC. 3. Add in call to update Optimizer. The diff splits are not ideal, but the full optimizer unit test will be added in D73049934 Differential Revision: D73530909 --- torchrec/distributed/embeddingbag.py | 29 ++++++++++ .../distributed/sharding/dynamic_sharding.py | 14 +++++ .../tests/test_dynamic_sharding.py | 57 ++++++++++++++++--- 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 4cb1d62c2..d7ae684e3 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -55,6 +55,7 @@ from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding from torchrec.distributed.sharding.dynamic_sharding import ( shards_all_to_all, + update_module_sharding_plan, update_state_dict_post_resharding, ) from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding @@ -1232,11 +1233,19 @@ def _update_output_dist(self) -> None: # TODO: Optimize to only go through embedding shardings with new ranks self._output_dists: List[nn.Module] = [] self._embedding_names: List[str] = [] + self._embedding_dims: List[int] = [] + self._uncombined_embedding_names: List[str] = [] + self._uncombined_embedding_dims: List[int] = [] for sharding in self._embedding_shardings: # TODO: if sharding type of table completely changes, need to regenerate everything self._embedding_names.extend(sharding.embedding_names()) self._output_dists.append(sharding.create_output_dist(device=self._device)) embedding_shard_metadata.extend(sharding.embedding_shard_metadata()) + self._embedding_dims.extend(sharding.embedding_dims()) + self._uncombined_embedding_names.extend( + sharding.uncombined_embedding_names() + ) + self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims()) embedding_shard_offsets: List[int] = [ meta.shard_offsets[1] if meta is not None else 0 @@ -1585,6 +1594,26 @@ def update_shards( self._initialize_torch_state(skip_registering=True) self.load_state_dict(current_state) + + # update optimizer + optims = [] + for lookup in self._lookups: + for _, tbe_module in lookup.named_modules(): + if isinstance(tbe_module, FusedOptimizerModule): + # modify param keys to match EmbeddingBagCollection + params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {} + for ( + param_key, + weight, + ) in tbe_module.fused_optimizer.params.items(): + # pyre-fixme[16]: `Mapping` has no attribute `__setitem__` + params["embedding_bags." + param_key] = weight + tbe_module.fused_optimizer.params = params + optims.append(("", tbe_module.fused_optimizer)) + + self._optim: CombinedOptimizer = CombinedOptimizer(optims) + + update_module_sharding_plan(self, changed_sharding_params) return @property diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py index 05ca485f2..caa937db2 100644 --- a/torchrec/distributed/sharding/dynamic_sharding.py +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -221,3 +221,17 @@ def update_state_dict_post_resharding( sharded_t._local_shards = [] return state_dict + + +def update_module_sharding_plan( + module: ShardedModule[Any, Any, Any, Any], # pyre-ignore + changed_sharding_params: Dict[str, ParameterSharding], +) -> None: + if not hasattr(module, "module_sharding_plan"): + return + + # pyre-ignore + current_plan: Dict[str, ParameterSharding] = module.module_sharding_plan + for table_name, param_sharding in changed_sharding_params.items(): + current_plan[table_name] = param_sharding + return diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index 63da24ba5..f9a07fc50 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -141,13 +141,10 @@ def create_test_initial_state_dict( return initial_state_dict -def are_modules_identical( - module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection], - module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection], +def are_sharded_ebc_modules_identical( + module1: ShardedEmbeddingBagCollection, + module2: ShardedEmbeddingBagCollection, ) -> None: - # Check if both modules have the same type - assert type(module1) is type(module2) - # Check if both modules have the same parameters params1 = list(module1.named_parameters()) params2 = list(module2.named_parameters()) @@ -170,6 +167,52 @@ def are_modules_identical( assert buffer1[0] == buffer2[0] # Check buffer names assert torch.allclose(buffer1[1], buffer2[1]) # Check buffer values + # Hard-coded attributes for EmbeddingBagCollection + attribute_list = [ + "_module_fqn", + "_table_names", + "_pooling_type_to_rs_features", + "_output_dtensor", + "_sharding_types", + "_is_weighted", + "_embedding_names", + "_embedding_dims", + "_feature_splits", + "_features_order", + "_uncombined_embedding_names", + "_uncombined_embedding_dims", + "_has_mean_pooling_callback", + "_kjt_key_indices", + "_has_uninitialized_input_dist", + "_has_features_permute", + "_dim_per_key", # Tensor + "_inverse_indices_permute_indices", # Tensor + "_kjt_inverse_order", # Tensor + "_kt_key_ordering", # Tensor + # Non-primitive types which can be compared + "module_sharding_plan", + "_table_name_to_config", + # Excluding the non-primitive types that cannot be compared + # "sharding_type_to_sharding_infos", + # "_embedding_shardings" + # "_input_dists", + # "_lookups", + # "_output_dists", + # "_optim", + ] + + for attr in attribute_list: + assert hasattr(module1, attr) and hasattr(module2, attr) + + val1 = getattr(module1, attr) + val2 = getattr(module2, attr) + + assert type(val1) is type(val2) + if type(val1) is torch.Tensor: + torch.testing.assert_close(val1, val2) + else: + assert val1 == val2 + def output_sharding_plan_delta( old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan @@ -274,7 +317,7 @@ def _test_ebc_resharding( device=ctx.device, ) - are_modules_identical(sharded_m1, resharded_m2) + are_sharded_ebc_modules_identical(sharded_m1, resharded_m2) feature_keys = [] for table in tables: