diff --git a/torchrec/distributed/composable/__init__.py b/torchrec/distributed/composable/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/tests/__init__.py b/torchrec/distributed/composable/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchrec/distributed/composable/tests/test_embeddingbag.py b/torchrec/distributed/composable/tests/test_embeddingbag.py new file mode 100644 index 000000000..69f5e9dea --- /dev/null +++ b/torchrec/distributed/composable/tests/test_embeddingbag.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest +from typing import Any, Dict, List, Optional + +import hypothesis.strategies as st +import torch +import torch.nn as nn +from hypothesis import assume, given, settings, Verbosity +from torchrec import distributed as trec_dist +from torchrec.distributed.embeddingbag import ( + EmbeddingBagCollectionSharder, + ShardedEmbeddingBagCollection, +) +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + ParameterConstraints, + Topology, +) + +from torchrec.distributed.shard import shard +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) +from torchrec.distributed.test_utils.test_sharding import copy_state_dict +from torchrec.distributed.types import ( + ModuleSharder, + QuantizedCommCodecs, + ShardingEnv, + ShardingPlan, + ShardingType, +) +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.modules.embedding_modules import EmbeddingBagCollection +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.test_utils import ( + assert_state_buffers_parameters_equal, + skip_if_asan_class, +) + + +def _test_sharding( # noqa C901 + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + rank: int, + world_size: int, + kjt_input_per_rank: List[KeyedJaggedTensor], + sharder: ModuleSharder[nn.Module], + backend: str, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + local_size: Optional[int] = None, + is_data_parallel: bool = False, + use_apply_optimizer_in_backward: bool = False, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank] + initial_state_dict = { + fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items() + } + + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size, ctx.device.type, local_world_size=ctx.local_size + ), + constraints=constraints, + ) + model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + unsharded_model = EmbeddingBagCollection( + tables=tables, + device=ctx.device, + ) + + if use_apply_optimizer_in_backward: + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_0"].parameters(), + {"lr": 1.0}, + ) + 
apply_optimizer_in_backward( + torch.optim.SGD, + unsharded_model.embedding_bags["table_1"].parameters(), + {"lr": 4.0}, + ) + plan: ShardingPlan = planner.collective_plan(model, [sharder], ctx.pg) + sharded_model, _ = shard( + module=model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=[sharder], + device=ctx.device, + ) + + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer = torch.optim.SGD( + unsharded_model.parameters(), lr=0.01 + ) + sharded_model_optimizer = torch.optim.SGD( + sharded_model.parameters(), lr=0.01 + ) + + assert isinstance(sharded_model, ShardedEmbeddingBagCollection) + + unsharded_model.load_state_dict(copy.deepcopy(initial_state_dict)) + copy_state_dict(sharded_model.state_dict(), copy.deepcopy(initial_state_dict)) + + feature_keys = [] + for table in tables: + feature_keys.extend(table.feature_names) + + for _it in range(5): + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.zero_grad() + sharded_model_optimizer.zero_grad() + + unsharded_model_pred_kt = [] + for unsharded_rank in range(ctx.world_size): + # simulate the unsharded model run on the entire batch + unsharded_model_pred_kt.append( + unsharded_model(kjt_input_per_rank[unsharded_rank]) + ) + + all_unsharded_preds = [] + for unsharded_rank in range(ctx.world_size): + unsharded_model_pred_kt_mini_batch = unsharded_model_pred_kt[ + unsharded_rank + ].to_dict() + + all_unsharded_preds.extend( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + if unsharded_rank == ctx.rank: + unsharded_model_pred = torch.stack( + [ + unsharded_model_pred_kt_mini_batch[feature] + for feature in feature_keys + ] + ) + # sharded model + # each rank gets a subbatch + sharded_model_pred_kt = sharded_model( + kjt_input_per_rank[ctx.rank] + ).to_dict() + sharded_model_pred = torch.stack( + [sharded_model_pred_kt[feature] for feature in feature_keys] + ) + + # cast to CPU because when casting unsharded_model.to on the same module, there could some race conditions + # in normal author modelling code this won't be an issue because each rank would individually create + # their model. output from sharded_pred is correctly on the correct device. + # Compare predictions of sharded vs unsharded models. 
+ torch.testing.assert_close( + sharded_model_pred.cpu(), unsharded_model_pred.cpu() + ) + + sharded_model_pred.sum().backward() + + all_unsharded_preds = torch.stack(all_unsharded_preds) + _sum = all_unsharded_preds.sum() + if is_data_parallel: + _sum /= world_size + _sum.backward() + if not use_apply_optimizer_in_backward: + unsharded_model_optimizer.step() + sharded_model_optimizer.step() + + # check nn.Module APIs look the same + assert_state_buffers_parameters_equal(unsharded_model, sharded_model) + + for fqn in unsharded_model.state_dict(): + unsharded_state = unsharded_model.state_dict()[fqn] + sharded_state = sharded_model.state_dict()[fqn] + + if is_data_parallel: + continue + else: + out = ( + torch.zeros(size=unsharded_state.shape, device=ctx.device) + if ctx.rank == 0 + else None + ) + sharded_state.gather(out=out) + if ctx.rank == 0: + torch.testing.assert_close( + unsharded_state, + out, + ) + + +class TestEmbeddingBagCollectionSharder(EmbeddingBagCollectionSharder): + def __init__( + self, + sharding_type: str, + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._sharding_type = sharding_type + + """ + Restricts sharding to single type only. + """ + + def sharding_types(self, compute_device_type: str) -> List[str]: + return [self._sharding_type] + + +@skip_if_asan_class +class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): + @unittest.skipIf( + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", + ) + # pyre-fixme[56] + @given( + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.DATA_PARALLEL.value, + ] + ), + use_apply_optimizer_in_backward=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_sharding_ebc( + self, + sharding_type: str, + use_apply_optimizer_in_backward: bool, + ) -> None: + + # TODO DistributedDataParallel needs full support of registering fused optims before we can enable this. 
+ assume( + not ( + use_apply_optimizer_in_backward + and sharding_type == ShardingType.DATA_PARALLEL.value + ), + ) + + WORLD_SIZE = 2 + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=4, + num_embeddings=4, + ), + EmbeddingBagConfig( + name="table_1", + feature_names=["feature_1"], + embedding_dim=4, + num_embeddings=4, + ), + ] + + # Rank 0 + # instance 0 instance 1 instance 2 + # "feature_0" [0, 1] None [2] + # "feature_1" [0, 1] None [2] + + # Rank 1 + + # instance 0 instance 1 instance 2 + # "feature_0" [3, 2] [1,2] [0,1,2,3] + # "feature_1" [2, 3] None [2] + + kjt_input_per_rank = [ # noqa + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([0, 1, 2, 0, 1, 2]), + lengths=torch.LongTensor([2, 0, 1, 2, 0, 1]), + ), + KeyedJaggedTensor.from_lengths_sync( + keys=["feature_0", "feature_1"], + values=torch.LongTensor([3, 2, 1, 2, 0, 1, 2, 3, 2, 3, 2]), + lengths=torch.LongTensor([2, 2, 4, 2, 0, 1]), + ), + ] + self._run_multi_process_test( + callable=_test_sharding, + world_size=WORLD_SIZE, + tables=embedding_bag_config, + initial_state_dict={ + "embedding_bags.table_0.weight": torch.Tensor( + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [4, 4, 4, 4], + [8, 8, 8, 8], + ] + ), + "embedding_bags.table_1.weight": torch.Tensor( + [ + [101, 101, 101, 101], + [102, 102, 102, 102], + [104, 104, 104, 104], + [108, 108, 108, 108], + ] + ), + }, + kjt_input_per_rank=kjt_input_per_rank, + sharder=TestEmbeddingBagCollectionSharder(sharding_type=sharding_type), + backend="nccl" + if (torch.cuda.is_available() and torch.cuda.device_count() >= 2) + else "gloo", + is_data_parallel=(sharding_type == ShardingType.DATA_PARALLEL.value), + use_apply_optimizer_in_backward=use_apply_optimizer_in_backward, + ) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 21973450e..50c5c1339 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -8,11 +8,24 @@ import copy from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) import torch from torch import nn, Tensor from torch.nn.modules.module import _IncompatibleKeys +from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -24,6 +37,7 @@ from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, EmbeddingComputeKernel, + GroupedEmbeddingConfig, SparseFeatures, SparseFeaturesList, ) @@ -47,11 +61,14 @@ ) from torchrec.distributed.utils import ( append_prefix, - filter_state_dict, merge_fused_params, optimizer_type_to_emb_opt_type, ) -from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType +from torchrec.modules.embedding_configs import ( + EmbeddingBagConfig, + EmbeddingTableConfig, + PoolingType, +) from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingBagCollectionInterface, @@ -289,9 +306,9 @@ class ShardedEmbeddingBagCollection( KeyedTensor, EmbeddingBagCollectionContext, ], + # TODO remove after compute_kernel X sharding decoupling FusedOptimizerModule, ): - # TODO remove after compute_kernel X sharding decoupling """ Sharded implementation of EmbeddingBagCollection. 
This is part of the public API to allow for manual data dist pipelining. @@ -308,6 +325,13 @@ def __init__( variable_batch_size: bool = False, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) + self._embedding_bag_configs: List[ + EmbeddingBagConfig + ] = module.embedding_bag_configs() + + self._table_name_to_parameter_sharding = table_name_to_parameter_sharding + self._env = env + sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( module, table_name_to_parameter_sharding, @@ -339,10 +363,10 @@ def __init__( self._is_weighted: bool = module.is_weighted() self._device = device - self._input_dists = nn.ModuleList() - self._lookups: nn.ModuleList = nn.ModuleList() + self._input_dists: List[nn.Module] = [] + self._lookups: List[nn.Module] = [] self._create_lookups() - self._output_dists: nn.ModuleList = nn.ModuleList() + self._output_dists: List[nn.Module] = [] self._embedding_names: List[str] = [] self._embedding_dims: List[int] = [] self._feature_splits: List[int] = [] @@ -353,6 +377,7 @@ def __init__( # forward pass flow control self._has_uninitialized_input_dist: bool = True self._has_features_permute: bool = True + # Get all fused optimizers and combine them. optims = [] for lookup in self._lookups: @@ -367,6 +392,122 @@ def __init__( optims.append(("", module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + for index, (sharding, lookup) in enumerate( + zip( + self._sharding_type_to_sharding.values(), + self._lookups, + ) + ): + if isinstance(sharding, DpPooledEmbeddingSharding): + self._lookups[index] = DistributedDataParallel( + module=lookup, + device_ids=[device] + if self._device and self._device.type == "gpu" + else None, + process_group=env.process_group, + # TODO investigate perf drop here + gradient_as_bucket_view=False, + broadcast_buffers=True, + static_graph=True, + ) + + self._initialize_torch_state() + + def _initialize_torch_state(self) -> None: # noqa + """ + This provides consistency between this class and the EmbeddingBagCollection's + nn.Module API calls (state_dict, named_modules, etc) + """ + + def hook_wrapper( + embedding_module: nn.Module, + config: GroupedEmbeddingConfig, + param_weights: torch.Tensor, + ) -> Callable[[Optional[torch.Tensor]], None]: + # pyre-ignore + def assign_param_grad_as_view_hook(*_args) -> None: + grad = param_weights.grad + assert grad is not None + for t_idx, (rows, dim) in enumerate(emb_module.embedding_specs): + table_name = config.embedding_tables[t_idx].name + offset = emb_module.weights_physical_offsets[t_idx] + # TODO move this logic to FBGEMM + this_grad = grad[offset : offset + rows * dim].view(rows, dim) + self.embedding_bags[table_name].weight.grad = this_grad + self._hooks["".join(config.feature_names())].remove() + + return assign_param_grad_as_view_hook + + self.embedding_bags: nn.ModuleDict = nn.ModuleDict() + self._hooks = {} + + model_parallel_name_to_local_shards = OrderedDict() + for ( + table_name, + parameter_sharding, + ) in self._table_name_to_parameter_sharding.items(): + if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value: + continue + model_parallel_name_to_local_shards[table_name] = [] + + name_to_table_size = {} + for table in self._embedding_bag_configs: + name_to_table_size[table.name] = (table.num_embeddings, table.embedding_dim) + + for sharding_type, lookup in zip( + self._sharding_type_to_sharding.keys(), self._lookups + ): + # TODO support dense kernels in model parallel sharding + if sharding_type == 
ShardingType.DATA_PARALLEL.value: + lookup = lookup.module + if self._is_weighted: + embeddings = lookup._score_emb_modules + configs = lookup.grouped_score_configs + else: + embeddings = lookup._emb_modules + configs = lookup.grouped_configs + + for config, embedding in zip(configs, embeddings): + emb_module = embedding._emb_module + for t_idx, weight in enumerate( + emb_module.split_embedding_weights() + ): + table_name = config.embedding_tables[t_idx].name + self.embedding_bags[table_name] = torch.nn.Module() + self.embedding_bags[table_name].register_parameter( + "weight", torch.nn.Parameter(weight) + ) + param_weights = dict(emb_module.named_parameters())["weights"] + param_weights.acc_grad = param_weights.view_as( + param_weights + ).grad_fn.next_functions[0][0] + + # Hooks are used to configure (per embedding lookup) post grad hooks that will + # create views of grad and remove() themselves after first call + self._hooks[ + "".join(config.feature_names()) + ] = param_weights.acc_grad.register_hook( + hook_wrapper(emb_module, config, param_weights) + ) + else: + lookup_state_dict = lookup.state_dict() + for key, v in lookup_state_dict.items(): + table_name = key[: -len(".weight")] + model_parallel_name_to_local_shards[table_name].extend( + v.local_shards() + ) + + for table_name, local_shards in model_parallel_name_to_local_shards.items(): + weight = ShardedTensor._init_from_local_shards( + local_shards, + name_to_table_size[table_name], + process_group=self._env.process_group, + ) + self.embedding_bags[table_name] = torch.nn.Module() + self.embedding_bags[table_name].register_parameter( + "weight", torch.nn.Parameter(weight) + ) + def _create_input_dist( self, input_feature_names: List[str], @@ -507,75 +648,34 @@ def compute_and_output_dist( embedding_names=self._embedding_names, ) - # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently. - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - for lookup in self._lookups: - lookup.state_dict(destination, prefix + "embedding_bags.", keep_vars) - return destination - - def named_modules( + def named_parameters( self, - memo: Optional[Set[nn.Module]] = None, prefix: str = "", + recurse: bool = True, remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, nn.Module]]: - yield from [(prefix, self)] - - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, nn.Parameter]]: - for lookup in self._lookups: - yield from lookup.named_parameters( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate - ) - - def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: - for lookup, sharding_type in zip( - self._lookups, self._sharding_type_to_sharding.keys() + # TODO remove when note needed + include_fused: bool = True, + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Args: + prefix (str): + recurse (bool): + remove_duplicate (bool): + include_fused (bool): flag for whether or not to include fused parameters. 
set to False for backward compatibility (of not returning fused) + """ + for name, param in super().named_parameters( + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate ): - if sharding_type == ShardingType.DATA_PARALLEL.value: - continue - for name, _ in lookup.named_parameters( - append_prefix(prefix, "embedding_bags") - ): - yield name - - def named_buffers( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - for lookup in self._lookups: - yield from lookup.named_buffers( - append_prefix(prefix, "embedding_bags"), recurse, remove_duplicate - ) - - # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` - # inconsistently. - def load_state_dict( - self, - state_dict: "OrderedDict[str, torch.Tensor]", - strict: bool = True, - ) -> _IncompatibleKeys: - missing_keys = [] - unexpected_keys = [] - for lookup in self._lookups: - missing, unexpected = lookup.load_state_dict( - filter_state_dict(state_dict, "embedding_bags"), - strict, - ) - missing_keys.extend(missing) - unexpected_keys.extend(unexpected) - return _IncompatibleKeys( - missing_keys=missing_keys, unexpected_keys=unexpected_keys - ) + is_fused = False + if "embedding_bags" in name: + pos = name.find("embedding_bags") + table_name = name[ + pos + len("embedding_bags") + 1 : -(len("weight") + 1) + ] + sharding = self._table_name_to_parameter_sharding[table_name] + is_fused = sharding.compute_kernel == "fused" + if include_fused or not is_fused: + yield name, param def sparse_grad_parameter_names( self, @@ -584,6 +684,7 @@ def sparse_grad_parameter_names( ) -> List[str]: destination = [] if destination is None else destination for lookup in self._lookups: + # pyre-ignore lookup.sparse_grad_parameter_names( destination, append_prefix(prefix, "embedding_bags") ) diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py index 31a4c31b6..345e31ed6 100644 --- a/torchrec/distributed/fused_embeddingbag.py +++ b/torchrec/distributed/fused_embeddingbag.py @@ -75,6 +75,7 @@ def __init__( broadcast_buffers=False, static_graph=True, ) + # pyre-ignore self._lookups[index]._register_fused_optim( optimizer_type, **optimizer_kwargs ) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 359a460e5..d45dc2c24 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -16,16 +16,17 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, sharder_name, Topology, ) - from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( ModuleSharder, ShardedModule, + ShardedTensor, ShardingEnv, ShardingPlan, ) @@ -75,28 +76,24 @@ def wrap( pg = env.process_group if pg is None: raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv") - sharded_parameter_names = { - key - for key in DistributedModelParallel._sharded_parameter_names( - dmp._dmp_wrapped_module - ) + sharded_parameter_names = set( + DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module) + ) + all_parameter_names = { + key for key, _ in dmp.named_parameters(include_fused=True) } - all_paramemeter_names = {key for key, _ in 
dmp.named_parameters()} - if sharded_parameter_names == all_paramemeter_names: + if len(all_parameter_names - sharded_parameter_names) == 0: return - DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( module=dmp._dmp_wrapped_module, - params_and_buffers_to_ignore=[ - key for key in all_paramemeter_names if key in sharded_parameter_names - ], + params_and_buffers_to_ignore=sharded_parameter_names, ) # initialize DDP dmp._dmp_wrapped_module = cast( nn.Module, DistributedDataParallel( module=dmp._dmp_wrapped_module.to(device), - device_ids=None if device.type == "cpu" else [device], + device_ids=[device] if device.type == "gpu" else None, process_group=pg, gradient_as_bucket_view=True, broadcast_buffers=False, @@ -185,6 +182,8 @@ def __init__( torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}") self.init_parameters = init_parameters + + self._dmp_wrapped_module = module self._ddp_wrapped: bool = False if env is None: @@ -341,7 +340,11 @@ def init_parameters(module: nn.Module) -> None: # Allocate parameters and buffers if over 'meta' device. has_meta_param = False for name, param in module._parameters.items(): - if isinstance(param, torch.Tensor) and param.device.type == "meta": + if ( + isinstance(param, torch.Tensor) + and not isinstance(param, ShardedTensor) + and param.device.type == "meta" + ): module._parameters[name] = nn.Parameter( torch.empty_like(param, device=self.device), requires_grad=param.requires_grad, @@ -449,22 +452,48 @@ def _named_parameters( prefix: str = "", recurse: bool = True, strip_ddp: bool = True, + include_fused: bool = False, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: if strip_ddp: module = get_unwrapped_module(module) if isinstance(module, ShardedModule): - yield from module.named_parameters(prefix, recurse) + if isinstance(module, ShardedEmbeddingBagCollection): + yield from module.named_parameters( + prefix, recurse, include_fused=include_fused + ) + else: + yield from module.named_parameters( + prefix, + recurse, + ) else: yield from module.named_parameters(prefix, recurse=False) for name, child in module.named_children(): yield from self._named_parameters( - child, append_prefix(prefix, name), recurse, strip_ddp + child, + append_prefix(prefix, name), + recurse, + strip_ddp, + include_fused=include_fused, ) def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, + # TODO remove when note needed + include_fused: bool = False, ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - gen = self._named_parameters(self.module, prefix, recurse) + """ + Args: + prefix (str): + recurse (bool): + include_fused (bool): flag for whether or not to include fused parameters. 
set to False for backward compatibility (of not returning fused) + """ + gen = self._named_parameters( + self.module, prefix, recurse, include_fused=include_fused + ) memo = set() for key, param in gen: if param in memo: diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 61e031d57..fb836d7ed 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -43,6 +43,7 @@ from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.test_utils import assert_state_buffers_parameters_equal from typing_extensions import Protocol @@ -337,7 +338,6 @@ def sharding_single_rank_test( sharders=sharders, device=ctx.device, ) - dense_optim = KeyedOptimizerWrapper( dict(local_model.named_parameters()), lambda params: torch.optim.SGD(params, lr=0.1), diff --git a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py b/torchrec/distributed/tests/test_fused_embedding_bag_collection.py index 0905c6638..a6716fef3 100644 --- a/torchrec/distributed/tests/test_fused_embedding_bag_collection.py +++ b/torchrec/distributed/tests/test_fused_embedding_bag_collection.py @@ -111,7 +111,7 @@ class FusedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_fused_ebc( self, sharder_type: str, @@ -168,7 +168,7 @@ def test_sharding_fused_ebc( ] ), ) - @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None) def test_sharding_fused_ebc_module_replace( self, sharding_type: str, diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py index b2e77a3e0..e290ceda6 100644 --- a/torchrec/distributed/tests/test_model_parallel.py +++ b/torchrec/distributed/tests/test_model_parallel.py @@ -75,7 +75,7 @@ class ModelParallelTest(ModelParallelTestShared): ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -278,7 +278,7 @@ def test_sharding_nccl_cw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -352,7 +352,7 @@ def test_sharding_nccl_tw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -418,7 +418,7 @@ def test_sharding_gloo_tw( ), kernel_type=st.sampled_from( [ - EmbeddingComputeKernel.DENSE.value, + # EmbeddingComputeKernel.DENSE.value, EmbeddingComputeKernel.FUSED.value, ] ), @@ -510,6 +510,9 @@ def test_sharding_gloo_dp( class ModelParallelSparseOnlyTest(unittest.TestCase): + def tearDown(self) -> None: + dist.destroy_process_group() + def test_sharding_ebc_as_top_level(self) -> None: os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" @@ -545,46 +548,6 @@ def test_sharding_ebc_as_top_level(self) -> None: model = DistributedModelParallel(ebc, device=curr_device) self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection)) - dist.destroy_process_group() - - def 
test_sharding_fused_ebc_as_top_level(self) -> None: - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - if torch.cuda.is_available(): - curr_device = torch.device("cuda:0") - torch.cuda.set_device(curr_device) - backend = "nccl" - else: - curr_device = torch.device("cpu") - backend = "gloo" - dist.init_process_group(backend=backend) - - embedding_dim = 128 - num_embeddings = 256 - ebc = FusedEmbeddingBagCollection( - device=torch.device("meta"), - tables=[ - EmbeddingBagConfig( - name="large_table", - embedding_dim=embedding_dim, - num_embeddings=num_embeddings, - feature_names=["my_feature"], - pooling=PoolingType.SUM, - ), - ], - optimizer_type=torch.optim.SGD, - optimizer_kwargs={"lr": 0.02}, - ) - - model = DistributedModelParallel(ebc, device=curr_device) - - self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection)) - dist.destroy_process_group() class ModelParallelStateDictTest(unittest.TestCase): @@ -900,9 +863,8 @@ def test_params_and_buffers( ] # pyre-ignore[6] (m, _), batch = self._generate_dmps_and_batch(sharders=sharders) - print(f"Sharding Plan: {m._plan}") state_dict_keys = set(m.state_dict().keys()) - param_keys = {key for (key, _) in m.named_parameters()} + param_keys = {key for (key, _) in m.named_parameters(include_fused=True)} buffer_keys = {key for (key, _) in m.named_buffers()} self.assertEqual(state_dict_keys, {*param_keys, *buffer_keys}) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py index 2f15d2114..c668c70ba 100644 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ b/torchrec/distributed/tests/test_train_pipeline.py @@ -78,7 +78,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]: def compute_kernels( self, sharding_type: str, compute_device_type: str ) -> List[str]: - return [EmbeddingComputeKernel.DENSE.value] + return [EmbeddingComputeKernel.FUSED.value] @dataclass @@ -270,7 +270,12 @@ def test_position_weighted_feature_processor(self) -> None: env=ShardingEnv.from_process_group(self.pg), init_data_parallel=True, device=self.device, - sharders=[cast(ModuleSharder[nn.Module], TestCustomEBCSharder())], + sharders=[ + cast( + ModuleSharder[nn.Module], + TestCustomEBCSharder(fused_params={"learning_rate": 0.1}), + ) + ], ) test_unsharded_model = TestSparseNN( tables=self.tables + fp_tables, diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index c71260f34..e132ecfd2 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -662,6 +662,13 @@ def loop(key: str, modules: List[nn.Module]) -> List[str]: return "\n ".join(rep) + def _initialize_torch_state(self) -> None: + """ + This provides consistency between this class and the ShardedModule's + nn.Module API calls (state_dict, named_modules, etc) + """ + pass + class ModuleSharder(abc.ABC, Generic[M]): """ diff --git a/torchrec/test_utils/__init__.py b/torchrec/test_utils/__init__.py index ecc799d5b..c9d0d1a32 100644 --- a/torchrec/test_utils/__init__.py +++ b/torchrec/test_utils/__init__.py @@ -12,12 +12,13 @@ import time from contextlib import closing from functools import wraps -from typing import Callable, Optional, TypeVar +from typing import Any, Callable, Dict, Optional, TypeVar import numpy as np import torch import torch.distributed as dist from pyre_extensions 
import ParameterSpecification +from torch import nn TParams = ParameterSpecification("TParams") TReturn = TypeVar("TReturn") @@ -95,3 +96,46 @@ def _wrapper(*args, **kwargs): return wrapped_func(*args, **kwargs) return _wrapper + + +def get_state_buffers_parameters(model: nn.Module) -> Dict[str, Any]: + return { + "state_dict": model.state_dict(), + "named_buffers": dict(model.named_buffers()), + "named_parameters": dict(model.named_parameters()), + } + + +def assert_state_buffers_parameters_equal( + model_1: nn.Module, + model_2: nn.Module, + check_named_buffers: bool = True, + check_named_parameters: bool = True, + check_state_dict: bool = True, +) -> None: + """ + Checks to see if the keys of top level PyTorch API calls are the same + between two modules. + """ + + model_characteristics = {} + model_characteristics["model_1"] = get_state_buffers_parameters(model_1) + model_characteristics["model_2"] = get_state_buffers_parameters(model_2) + + assert ( + not check_named_buffers + or model_characteristics["model_1"]["named_buffers"].keys() + == model_characteristics["model_2"]["named_buffers"].keys() + ), "named buffers keys are not the same" + + assert ( + not check_named_parameters + or model_characteristics["model_1"]["named_parameters"].keys() + == model_characteristics["model_2"]["named_parameters"].keys() + ), "named parameter keys are not the same" + + assert ( + not check_state_dict + or model_characteristics["model_1"]["state_dict"].keys() + == model_characteristics["model_2"]["state_dict"].keys() + ), f"state dict key are not the same, {model_characteristics['model_1']['state_dict'].keys()} vs {model_characteristics['model_2']['state_dict'].keys()}"
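
A minimal single-rank sketch of the flow the new composable test exercises (build an EmbeddingBagCollection, attach an optimizer via apply_optimizer_in_backward, plan and shard with the composable shard() API, then read parameters back through the regular nn.Module surface). This is not part of the patch; the single-rank process group setup, default sharder construction, table config, port, and learning rate below are illustrative assumptions.

# Minimal sketch mirroring the new test_embeddingbag.py flow; single-rank,
# illustrative values only.
import os

import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.shard import shard
from torchrec.distributed.types import ShardingEnv
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# Single-rank process group; gloo on CPU is enough for the sketch.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")  # assumed free port
dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo", rank=0, world_size=1
)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=4,
            num_embeddings=4,
        )
    ],
    device=device,
)

# Register SGD to run in the backward pass for this table; the sharded module
# later folds it into its CombinedOptimizer.
apply_optimizer_in_backward(
    torch.optim.SGD, ebc.embedding_bags["table_0"].parameters(), {"lr": 1.0}
)

pg = dist.group.WORLD
sharder = EmbeddingBagCollectionSharder()
planner = EmbeddingShardingPlanner(
    topology=Topology(dist.get_world_size(), device.type)
)
plan = planner.collective_plan(ebc, [sharder], pg)

sharded_ebc, _ = shard(
    module=ebc,
    env=ShardingEnv.from_process_group(pg),
    plan=plan,
    sharders=[sharder],
    device=device,
)

# With _initialize_torch_state(), the sharded module mirrors the unsharded
# module's nn.Module surface; fused table weights are included on request.
params = dict(sharded_ebc.named_parameters(include_fused=True))

From here, the new assert_state_buffers_parameters_equal helper in torchrec.test_utils can be used to check that the sharded module's state_dict, named_parameters, and named_buffers keys line up with those of an unsharded EmbeddingBagCollection built from the same tables, which is exactly what the multi-rank test above does per rank.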