From e8543a168f8f27953f658699c2a682179d5f62c5 Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Mon, 14 Nov 2022 16:52:04 -0800 Subject: [PATCH] update ShardedEmbeddingBagCollection to be use registered EBCs with shardedTensors as registered modules (#88026) Summary: X-link: https://github.com/pytorch/pytorch/pull/88026 Pull Request resolved: https://github.com/pytorch/torchrec/pull/758 update ShardedEmbeddingBagCollection to be composable according to https://docs.google.com/document/d/1TBJSd5zgEg6cRcXv3Okuj7bBkqQwGS2IPh4TLWNNzFI/edit this works with DMP named_parameters() behavior changes -> use include_fused as temporary flag to gate this behavior note that due to ShardedTensor not supporting grads directly, this won't work for Dense compute kernels when non data parallel. This is not used today, and will add a TODO but is low pri Differential Revision: D40458625 fbshipit-source-id: 9135216ac67c828d8532d5c251cd6b8d170c058b --- torchrec/distributed/composable/__init__.py | 0 .../distributed/composable/tests/__init__.py | 0 .../composable/tests/test_embeddingbag.py | 334 ++++++++++++++++++ torchrec/distributed/embeddingbag.py | 245 +++++++++---- torchrec/distributed/fused_embeddingbag.py | 1 + torchrec/distributed/model_parallel.py | 65 +++- .../distributed/test_utils/test_sharding.py | 2 +- .../test_fused_embedding_bag_collection.py | 4 +- .../distributed/tests/test_model_parallel.py | 54 +-- .../distributed/tests/test_train_pipeline.py | 9 +- torchrec/distributed/types.py | 7 + torchrec/test_utils/__init__.py | 46 ++- 12 files changed, 625 insertions(+), 142 deletions(-) create mode 100644 torchrec/distributed/composable/__init__.py create mode 100644 torchrec/distributed/composable/tests/__init__.py create mode 100644 torchrec/distributed/composable/tests/test_embeddingbag.py diff --git a/torchrec/distributed/composable/__init__.py b/torchrec/distributed/composable/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/tests/__init__.py b/torchrec/distributed/composable/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/tests/test_embeddingbag.py b/torchrec/distributed/composable/tests/test_embeddingbag.py new file mode 100644 index 000000000..69f5e9dea --- /dev/null +++ b/torchrec/distributed/composable/tests/test_embeddingbag.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from typing import Any, Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn +from hypothesis import assume, given, settings, Verbosity +from torchrec import distributed as trec_dist +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) + +from torchrec.distributed.shard import shard +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.types import ( + ModuleSharder, + QuantizedCommCodecs, + ShardingEnv, + ShardingPlan, + ShardingType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import ( + assert_state_buffers_parameters_equal, + skip_if_asan_class, +) + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + local_size: Optional[int] = None, + is_data_parallel: bool = False, + use_apply_optimizer_in_backward: bool = False, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size, ctx.device.type, local_world_size=ctx.local_size + ), + constraints=constraints, + ) + model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + unsharded_model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + if use_apply_optimizer_in_backward: + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + plan: ShardingPlan = planner.collective_plan(model, [sharder], ctx.pg) + sharded_model, _ = shard( + module=model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=[sharder], + device=ctx.device, + ) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer = torch.optim.SGD( + unsharded_model.parameters(), lr=0.01 + ) + sharded_model_optimizer = torch.optim.SGD( + sharded_model.parameters(), lr=0.01 + ) + + assert isinstance(sharded_model, ShardedEmbeddingBagCollection) + + unsharded_model.load_state_dict(copy.deepcopy(initial_state_dict)) + copy_state_dict(sharded_model.state_dict(), copy.deepcopy(initial_state_dict)) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + for _it in range(5): + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.zero_grad() + sharded_model_optimizer.zero_grad() + + unsharded_model_pred_kt = [] + for unsharded_rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_pred_kt.append( + unsharded_model(kjt_input_per_rank[unsharded_rank]) + ) + + all_unsharded_preds = [] + for unsharded_rank in range(ctx.world_size): + unsharded_model_pred_kt_mini_batch = unsharded_model_pred_kt[ + unsharded_rank + ].to_dict() + + all_unsharded_preds.extend( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + if unsharded_rank == ctx.rank: + unsharded_model_pred = torch.stack( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + # sharded model + # each rank gets a subbatch + sharded_model_pred_kt = sharded_model( + kjt_input_per_rank[ctx.rank] + ).to_dict() + sharded_model_pred = torch.stack( + [sharded_model_pred_kt[feature] for feature in feature_keys] + ) + + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. + # Compare predictions of sharded vs unsharded models. + torch.testing.assert_close( + sharded_model_pred.cpu(), unsharded_model_pred.cpu() + ) + + sharded_model_pred.sum().backward() + + all_unsharded_preds = torch.stack(all_unsharded_preds) + _sum = all_unsharded_preds.sum() + if is_data_parallel: + _sum /= world_size + _sum.backward() + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.step() + sharded_model_optimizer.step() + + # check nn.Module APIs look the same + assert_state_buffers_parameters_equal(unsharded_model, sharded_model) + + for fqn in unsharded_model.state_dict(): + unsharded_state = unsharded_model.state_dict()[fqn] + sharded_state = sharded_model.state_dict()[fqn] + + if is_data_parallel: + continue + else: + out = ( + torch.zeros(size=unsharded_state.shape, device=ctx.device) + if ctx.rank == 0 + else None + ) + sharded_state.gather(out=out) + if ctx.rank == 0: + torch.testing.assert_close( + unsharded_state, + out, + ) + + +class TestEmbeddingBagCollectionSharder(EmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._sharding_type = sharding_type + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +@skip_if_asan_class +class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + ), + use_apply_optimizer_in_backward=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_ebc( + self, + sharding_type: str, + use_apply_optimizer_in_backward: bool, + ) -> None: + + # TODO DistributedDataParallel needs full support of registering fused optims before we can enable this. + assume( + not ( + use_apply_optimizer_in_backward + and sharding_type == ShardingType.DATA_PARALLEL.value + ), + ) + + WORLD_SIZE = 2 + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=4, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=4, + ), + ] + + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] None [2] + # "feature_1" [0, 1] None [2] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0,1,2,3] + # "feature_1" [2, 3] None [2] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([0, 1, 2, 0, 1, 2]), + lengths=torch.LongTensor([2, 0, 1, 2, 0, 1]), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2]), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1]), + ), + ] + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + initial_state_dict={ + "embedding_bags.table_0.weight": torch.Tensor( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [4, 4, 4, 4], + [8, 8, 8, 8], + ] + ), + "embedding_bags.table_1.weight": torch.Tensor( + [ + [101, 101, 101, 101], + [102, 102, 102, 102], + [104, 104, 104, 104], + [108, 108, 108, 108], + ] + ), + }, + kjt_input_per_rank=kjt_input_per_rank, + sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type), + backend="nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo", + is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value), + use_apply_optimizer_in_backward=use_apply_optimizer_in_backward, + ) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 21973450e..50c5c1339 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -8,11 +8,24 @@ import copy from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) import torch from torch import nn, Tensor from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -24,6 +37,7 @@ from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, + GroupedEmbeddingConfig, SparseFeatures, SparseFeaturesList, ) @@ -47,11 +61,14 @@ ) from torchrec.distributed.utils import ( append_prefix, - filter_state_dict, merge_fused_params, optimizer_type_to_emb_opt_type, ) -from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + EmbeddingTableConfig, + PoolingType, +) from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingBagCollectionInterface, @@ -289,9 +306,9 @@ class ShardedEmbeddingBagCollection( KeyedTensor, EmbeddingBagCollectionContext, ], + # TODO remove after compute_kernel X sharding decoupling FusedOptimizerModule, ): - # TODO remove after compute_kernel X sharding decoupling """ Sharded implementation of EmbeddingBagCollection. This is part of the public API to allow for manual data dist pipelining. @@ -308,6 +325,13 @@ def __init__( variable_batch_size: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._embedding_bag_configs: List[ + EmbeddingBagConfig + ] = module.embedding_bag_configs() + + self._table_name_to_parameter_sharding = table_name_to_parameter_sharding + self._env = env + sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( module, table_name_to_parameter_sharding, @@ -339,10 +363,10 @@ def __init__( self._is_weighted: bool = module.is_weighted() self._device = device - self._input_dists = nn.ModuleList() - self._lookups: nn.ModuleList = nn.ModuleList() + self._input_dists: List[nn.Module] = [] + self._lookups: List[nn.Module] = [] self._create_lookups() - self._output_dists: nn.ModuleList = nn.ModuleList() + self._output_dists: List[nn.Module] = [] self._embedding_names: List[str] = [] self._embedding_dims: List[int] = [] self._feature_splits: List[int] = [] @@ -353,6 +377,7 @@ def __init__( # forward pass flow control self._has_uninitialized_input_dist: bool = True self._has_features_permute: bool = True + # Get all fused optimizers and combine them. optims = [] for lookup in self._lookups: @@ -367,6 +392,122 @@ def __init__( optims.append(("", module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + for index, (sharding, lookup) in enumerate( + zip( + self._sharding_type_to_sharding.values(), + self._lookups, + ) + ): + if isinstance(sharding, DpPooledEmbeddingSharding): + self._lookups[index] = DistributedDataParallel( + module=lookup, + device_ids=[device] + if self._device and self._device.type == "gpu" + else None, + process_group=env.process_group, + # TODO investigate perf drop here + gradient_as_bucket_view=False, + broadcast_buffers=True, + static_graph=True, + ) + + self._initialize_torch_state() + + def _initialize_torch_state(self) -> None: # noqa + """ + This provides consistency between this class and the EmbeddingBagCollection's + nn.Module API calls (state_dict, named_modules, etc) + """ + + def hook_wrapper( + embedding_module: nn.Module, + config: GroupedEmbeddingConfig, + param_weights: torch.Tensor, + ) -> Callable[[Optional[torch.Tensor]], None]: + # pyre-ignore + def assign_param_grad_as_view_hook(*_args) -> None: + grad = param_weights.grad + assert grad is not None + for t_idx, (rows, dim) in enumerate(emb_module.embedding_specs): + table_name = config.embedding_tables[t_idx].name + offset = emb_module.weights_physical_offsets[t_idx] + # TODO move this logic to FBGEMM + this_grad = grad[offset : offset + rows * dim].view(rows, dim) + self.embedding_bags[table_name].weight.grad = this_grad + self._hooks["".join(config.feature_names())].remove() + + return assign_param_grad_as_view_hook + + self.embedding_bags: nn.ModuleDict = nn.ModuleDict() + self._hooks = {} + + model_parallel_name_to_local_shards = OrderedDict() + for ( + table_name, + parameter_sharding, + ) in self._table_name_to_parameter_sharding.items(): + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + continue + model_parallel_name_to_local_shards[table_name] = [] + + name_to_table_size = {} + for table in self._embedding_bag_configs: + name_to_table_size[table.name] = (table.num_embeddings, table.embedding_dim) + + for sharding_type, lookup in zip( + self._sharding_type_to_sharding.keys(), self._lookups + ): + # TODO support dense kernels in model parallel sharding + if sharding_type == ShardingType.DATA_PARALLEL.value: + lookup = lookup.module + if self._is_weighted: + embeddings = lookup._score_emb_modules + configs = lookup.grouped_score_configs + else: + embeddings = lookup._emb_modules + configs = lookup.grouped_configs + + for config, embedding in zip(configs, embeddings): + emb_module = embedding._emb_module + for t_idx, weight in enumerate( + emb_module.split_embedding_weights() + ): + table_name = config.embedding_tables[t_idx].name + self.embedding_bags[table_name] = torch.nn.Module() + self.embedding_bags[table_name].register_parameter( + "weight", torch.nn.Parameter(weight) + ) + param_weights = dict(emb_module.named_parameters())["weights"] + param_weights.acc_grad = param_weights.view_as( + param_weights + ).grad_fn.next_functions[0][0] + + # Hooks are used to configure (per embedding lookup) post grad hooks that will + # create views of grad and remove() themselves after first call + self._hooks[ + "".join(config.feature_names()) + ] = param_weights.acc_grad.register_hook( + hook_wrapper(emb_module, config, param_weights) + ) + else: + lookup_state_dict = lookup.state_dict() + for key, v in lookup_state_dict.items(): + table_name = key[: -len(".weight")] + model_parallel_name_to_local_shards[table_name].extend( + v.local_shards() + ) + + for table_name, local_shards in model_parallel_name_to_local_shards.items(): + weight = ShardedTensor._init_from_local_shards( + local_shards, + name_to_table_size[table_name], + process_group=self._env.process_group, + ) + self.embedding_bags[table_name] = torch.nn.Module() + self.embedding_bags[table_name].register_parameter( + "weight", torch.nn.Parameter(weight) + ) + def _create_input_dist( self, input_feature_names: List[str], @@ -507,75 +648,34 @@ def compute_and_output_dist( embedding_names=self._embedding_names, ) - # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for lookup in self._lookups: - lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) - return destination - - def named_modules( + def named_parameters( self, - memo: Optional[Set[nn.Module]] = None, prefix: str = "", + recurse: bool = True, remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] - - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for lookup in self._lookups: - yield from lookup.named_parameters( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate - ) - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for lookup, sharding_type in zip( - self._lookups, self._sharding_type_to_sharding.keys() + # TODO remove when note needed + include_fused: bool = True, + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Args: + prefix (str): + recurse (bool): + remove_duplicate (bool): + include_fused (bool): flag for whether or not to include fused parameters. set to False for backward compatibility (of not returning fused) + """ + for name, param in super().named_parameters( + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - continue - for name, _ in lookup.named_parameters( - append_prefix(prefix, "embedding_bags") - ): - yield name - - def named_buffers( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for lookup in self._lookups: - yield from lookup.named_buffers( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate - ) - - # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` - # inconsistently. - def load_state_dict( - self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - for lookup in self._lookups: - missing, unexpected = lookup.load_state_dict( - filter_state_dict(state_dict, "embedding_bags"), - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - return _IncompatibleKeys( - missing_keys=missing_keys, unexpected_keys=unexpected_keys - ) + is_fused = False + if "embedding_bags" in name: + pos = name.find("embedding_bags") + table_name = name[ + pos + len("embedding_bags") + 1 : -(len("weight") + 1) + ] + sharding = self._table_name_to_parameter_sharding[table_name] + is_fused = sharding.compute_kernel == "fused" + if include_fused or not is_fused: + yield name, param def sparse_grad_parameter_names( self, @@ -584,6 +684,7 @@ def sparse_grad_parameter_names( ) -> List[str]: destination = [] if destination is None else destination for lookup in self._lookups: + # pyre-ignore lookup.sparse_grad_parameter_names( destination, append_prefix(prefix, "embedding_bags") ) diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py index 31a4c31b6..345e31ed6 100644 --- a/torchrec/distributed/fused_embeddingbag.py +++ b/torchrec/distributed/fused_embeddingbag.py @@ -75,6 +75,7 @@ def __init__( broadcast_buffers=False, static_graph=True, ) + # pyre-ignore self._lookups[index]._register_fused_optim( optimizer_type, **optimizer_kwargs ) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 359a460e5..d45dc2c24 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -16,16 +16,17 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, sharder_name, Topology, ) - from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( ModuleSharder, ShardedModule, + ShardedTensor, ShardingEnv, ShardingPlan, ) @@ -75,28 +76,24 @@ def wrap( pg = env.process_group if pg is None: raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv") - sharded_parameter_names = { - key - for key in DistributedModelParallel._sharded_parameter_names( - dmp._dmp_wrapped_module - ) + sharded_parameter_names = set( + DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module) + ) + all_parameter_names = { + key for key, _ in dmp.named_parameters(include_fused=True) } - all_paramemeter_names = {key for key, _ in dmp.named_parameters()} - if sharded_parameter_names == all_paramemeter_names: + if len(all_parameter_names - sharded_parameter_names) == 0: return - DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( module=dmp._dmp_wrapped_module, - params_and_buffers_to_ignore=[ - key for key in all_paramemeter_names if key in sharded_parameter_names - ], + params_and_buffers_to_ignore=sharded_parameter_names, ) # initialize DDP dmp._dmp_wrapped_module = cast( nn.Module, DistributedDataParallel( module=dmp._dmp_wrapped_module.to(device), - device_ids=None if device.type == "cpu" else [device], + device_ids=[device] if device.type == "gpu" else None, process_group=pg, gradient_as_bucket_view=True, broadcast_buffers=False, @@ -185,6 +182,8 @@ def __init__( torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self.init_parameters = init_parameters + + self._dmp_wrapped_module = module self._ddp_wrapped: bool = False if env is None: @@ -341,7 +340,11 @@ def init_parameters(module: nn.Module) -> None: # Allocate parameters and buffers if over 'meta' device. has_meta_param = False for name, param in module._parameters.items(): - if isinstance(param, torch.Tensor) and param.device.type == "meta": + if ( + isinstance(param, torch.Tensor) + and not isinstance(param, ShardedTensor) + and param.device.type == "meta" + ): module._parameters[name] = nn.Parameter( torch.empty_like(param, device=self.device), requires_grad=param.requires_grad, @@ -449,22 +452,48 @@ def _named_parameters( prefix: str = "", recurse: bool = True, strip_ddp: bool = True, + include_fused: bool = False, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: if strip_ddp: module = get_unwrapped_module(module) if isinstance(module, ShardedModule): - yield from module.named_parameters(prefix, recurse) + if isinstance(module, ShardedEmbeddingBagCollection): + yield from module.named_parameters( + prefix, recurse, include_fused=include_fused + ) + else: + yield from module.named_parameters( + prefix, + recurse, + ) else: yield from module.named_parameters(prefix, recurse=False) for name, child in module.named_children(): yield from self._named_parameters( - child, append_prefix(prefix, name), recurse, strip_ddp + child, + append_prefix(prefix, name), + recurse, + strip_ddp, + include_fused=include_fused, ) def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, + # TODO remove when note needed + include_fused: bool = False, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - gen = self._named_parameters(self.module, prefix, recurse) + """ + Args: + prefix (str): + recurse (bool): + include_fused (bool): flag for whether or not to include fused parameters. set to False for backward compatibility (of not returning fused) + """ + gen = self._named_parameters( + self.module, prefix, recurse, include_fused=include_fused + ) memo = set() for key, param in gen: if param in memo: diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 61e031d57..fb836d7ed 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -43,6 +43,7 @@ from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.test_utils import assert_state_buffers_parameters_equal from typing_extensions import Protocol @@ -337,7 +338,6 @@ def sharding_single_rank_test( sharders=sharders, device=ctx.device, ) - dense_optim = KeyedOptimizerWrapper( dict(local_model.named_parameters()), lambda params: torch.optim.SGD(params, lr=0.1), diff --git a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py b/torchrec/distributed/tests/test_fused_embedding_bag_collection.py index 0905c6638..a6716fef3 100644 --- a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py +++ b/torchrec/distributed/tests/test_fused_embedding_bag_collection.py @@ -111,7 +111,7 @@ class FusedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_fused_ebc( self, sharder_type: str, @@ -168,7 +168,7 @@ def test_sharding_fused_ebc( ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_fused_ebc_module_replace( self, sharding_type: str, diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py index b2e77a3e0..e290ceda6 100644 --- a/torchrec/distributed/tests/test_model_parallel.py +++ b/torchrec/distributed/tests/test_model_parallel.py @@ -75,7 +75,7 @@ class ModelParallelTest(ModelParallelTestShared): ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -278,7 +278,7 @@ def test_sharding_nccl_cw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -352,7 +352,7 @@ def test_sharding_nccl_tw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -418,7 +418,7 @@ def test_sharding_gloo_tw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -510,6 +510,9 @@ def test_sharding_gloo_dp( class ModelParallelSparseOnlyTest(unittest.TestCase): + def tearDown(self) -> None: + dist.destroy_process_group() + def test_sharding_ebc_as_top_level(self) -> None: os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" @@ -545,46 +548,6 @@ def test_sharding_ebc_as_top_level(self) -> None: model = DistributedModelParallel(ebc, device=curr_device) self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection)) - dist.destroy_process_group() - - def test_sharding_fused_ebc_as_top_level(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" - else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - - embedding_dim = 128 - num_embeddings = 256 - ebc = FusedEmbeddingBagCollection( - device=torch.device("meta"), - tables=[ - EmbeddingBagConfig( - name="large_table", - embedding_dim=embedding_dim, - num_embeddings=num_embeddings, - feature_names=["my_feature"], - pooling=PoolingType.SUM, - ), - ], - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - ) - - model = DistributedModelParallel(ebc, device=curr_device) - - self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection)) - dist.destroy_process_group() class ModelParallelStateDictTest(unittest.TestCase): @@ -900,9 +863,8 @@ def test_params_and_buffers( ] # pyre-ignore[6] (m, _), batch = self._generate_dmps_and_batch(sharders=sharders) - print(f"Sharding Plan: {m._plan}") state_dict_keys = set(m.state_dict().keys()) - param_keys = {key for (key, _) in m.named_parameters()} + param_keys = {key for (key, _) in m.named_parameters(include_fused=True)} buffer_keys = {key for (key, _) in m.named_buffers()} self.assertEqual(state_dict_keys, {*param_keys, *buffer_keys}) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py index 2f15d2114..c668c70ba 100644 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ b/torchrec/distributed/tests/test_train_pipeline.py @@ -78,7 +78,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]: def compute_kernels( self, sharding_type: str, compute_device_type: str ) -> List[str]: - return [EmbeddingComputeKernel.DENSE.value] + return [EmbeddingComputeKernel.FUSED.value] @dataclass @@ -270,7 +270,12 @@ def test_position_weighted_feature_processor(self) -> None: env=ShardingEnv.from_process_group(self.pg), init_data_parallel=True, device=self.device, - sharders=[cast(ModuleSharder[nn.Module], TestCustomEBCSharder())], + sharders=[ + cast( + ModuleSharder[nn.Module], + TestCustomEBCSharder(fused_params={"learning_rate": 0.1}), + ) + ], ) test_unsharded_model = TestSparseNN( tables=self.tables + fp_tables, diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index c71260f34..e132ecfd2 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -662,6 +662,13 @@ def loop(key: str, modules: List[nn.Module]) -> List[str]: return "\n ".join(rep) + def _initialize_torch_state(self) -> None: + """ + This provides consistency between this class and the ShardedModule's + nn.Module API calls (state_dict, named_modules, etc) + """ + pass + class ModuleSharder(abc.ABC, Generic[M]): """ diff --git a/torchrec/test_utils/__init__.py b/torchrec/test_utils/__init__.py index ecc799d5b..c9d0d1a32 100644 --- a/torchrec/test_utils/__init__.py +++ b/torchrec/test_utils/__init__.py @@ -12,12 +12,13 @@ import time from contextlib import closing from functools import wraps -from typing import Callable, Optional, TypeVar +from typing import Any, Callable, Dict, Optional, TypeVar import numpy as np import torch import torch.distributed as dist from pyre_extensions import ParameterSpecification +from torch import nn TParams = ParameterSpecification("TParams") TReturn = TypeVar("TReturn") @@ -95,3 +96,46 @@ def _wrapper(*args, **kwargs): return wrapped_func(*args, **kwargs) return _wrapper + + +def get_state_buffers_parameters(model: nn.Module) -> Dict[str, Any]: + return { + "state_dict": model.state_dict(), + "named_buffers": dict(model.named_buffers()), + "named_parameters": dict(model.named_parameters()), + } + + +def assert_state_buffers_parameters_equal( + model_1: nn.Module, + model_2: nn.Module, + check_named_buffers: bool = True, + check_named_parameters: bool = True, + check_state_dict: bool = True, +) -> None: + """ + Checks to see if the keys of top level PyTorch API calls are the same + between two modules. + """ + + model_characteristics = {} + model_characteristics["model_1"] = get_state_buffers_parameters(model_1) + model_characteristics["model_2"] = get_state_buffers_parameters(model_2) + + assert ( + not check_named_buffers + or model_characteristics["model_1"]["named_buffers"].keys() + == model_characteristics["model_2"]["named_buffers"].keys() + ), "named buffers keys are not the same" + + assert ( + not check_named_parameters + or model_characteristics["model_1"]["named_parameters"].keys() + == model_characteristics["model_2"]["named_parameters"].keys() + ), "named parameter keys are not the same" + + assert ( + not check_state_dict + or model_characteristics["model_1"]["state_dict"].keys() + == model_characteristics["model_2"]["state_dict"].keys() + ), f"state dict key are not the same, {model_characteristics['model_1']['state_dict'].keys()} vs {model_characteristics['model_2']['state_dict'].keys()}"