From 9d45531a2799839cd36b1fc09186b1abfe134fd0 Mon Sep 17 00:00:00 2001 From: Justin Yang Date: Wed, 12 Nov 2025 04:08:50 -0800 Subject: [PATCH] Back out "Add row based sharding support for FeaturedProcessedEBC" (#3537) Summary: Original commit changeset: 4a8ad3bc6d14 Original Phabricator Diff: D82248545 Reviewed By: sarckk, really121, aliafzal Differential Revision: D86779963 --- torchrec/distributed/fp_embeddingbag.py | 59 +------ .../distributed/tests/test_fp_embeddingbag.py | 1 + .../tests/test_fp_embeddingbag_utils.py | 7 +- .../tests/test_train_pipelines.py | 165 +----------------- torchrec/distributed/train_pipeline/utils.py | 3 - torchrec/distributed/types.py | 20 --- torchrec/distributed/utils.py | 51 +----- torchrec/modules/feature_processor_.py | 49 ++---- 8 files changed, 25 insertions(+), 330 deletions(-) diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py index 3d7fd4140..4b069437f 100644 --- a/torchrec/distributed/fp_embeddingbag.py +++ b/torchrec/distributed/fp_embeddingbag.py @@ -8,18 +8,7 @@ # pyre-strict from functools import partial -from typing import ( - Any, - Dict, - Iterator, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Dict, Iterator, List, Optional, Type, Union import torch from torch import nn @@ -42,11 +31,7 @@ ShardingEnv, ShardingType, ) -from torchrec.distributed.utils import ( - append_prefix, - init_parameters, - modify_input_for_feature_processor, -) +from torchrec.distributed.utils import append_prefix, init_parameters from torchrec.modules.feature_processor_ import FeatureProcessorsCollection from torchrec.modules.fp_embedding_modules import ( apply_feature_processors_to_kjt, @@ -54,8 +39,6 @@ ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -_T = TypeVar("_T") - def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor: kt._values.add_(no_op_tensor) @@ -91,16 +74,6 @@ def __init__( ) ) - self._row_wise_sharded: bool = False - for param_sharding in table_name_to_parameter_sharding.values(): - if param_sharding.sharding_type in [ - ShardingType.ROW_WISE.value, - ShardingType.TABLE_ROW_WISE.value, - ShardingType.GRID_SHARD.value, - ]: - self._row_wise_sharded = True - break - self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups self._is_collection: bool = False @@ -123,11 +96,6 @@ def __init__( def input_dist( self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor ) -> Awaitable[Awaitable[KJTList]]: - if not self.is_pipelined and self._row_wise_sharded: - # transform input to support row based sharding when not pipelined - modify_input_for_feature_processor( - features, self._feature_processors, self._is_collection - ) return self._embedding_bag_collection.input_dist(ctx, features) def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList: @@ -137,7 +105,10 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList: kjt_list.append(self._feature_processors(features)) else: kjt_list.append( - apply_feature_processors_to_kjt(features, self._feature_processors) + apply_feature_processors_to_kjt( + features, + self._feature_processors, + ) ) return KJTList(kjt_list) @@ -146,6 +117,7 @@ def compute( ctx: EmbeddingBagCollectionContext, dist_input: KJTList, ) -> List[torch.Tensor]: + fp_features = self.apply_feature_processors_to_kjt_list(dist_input) return self._embedding_bag_collection.compute(ctx, fp_features) @@ -194,18 +166,6 @@ def 
sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa self._embedding_bag_collection._initialize_torch_state(skip_registering) - def preprocess_input( - self, args: List[_T], kwargs: Mapping[str, _T] - ) -> Tuple[List[_T], Mapping[str, _T]]: - for x in args + list(kwargs.values()): - if isinstance(x, KeyedJaggedTensor): - modify_input_for_feature_processor( - features=x, - feature_processors=self._feature_processors, - is_collection=self._is_collection, - ) - return args, kwargs - class FeatureProcessedEmbeddingBagCollectionSharder( BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection] @@ -231,6 +191,7 @@ def shard( device: Optional[torch.device] = None, module_fqn: Optional[str] = None, ) -> ShardedFeatureProcessedEmbeddingBagCollection: + if device is None: device = torch.device("cuda") @@ -267,14 +228,12 @@ def sharding_types(self, compute_device_type: str) -> List[str]: if compute_device_type in {"mtia"}: return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value] + # No row wise because position weighted FP and RW don't play well together. types = [ ShardingType.DATA_PARALLEL.value, ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, - ShardingType.TABLE_ROW_WISE.value, - ShardingType.ROW_WISE.value, - ShardingType.GRID_SHARD.value, ] return types diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py index 08f5dfdbb..130776919 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag.py @@ -231,6 +231,7 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): def test_sharding_ebc( self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool ) -> None: + import hypothesis # don't need to test entire matrix diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py index e39ed4310..b4cee4070 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py @@ -86,12 +86,7 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]: pred = torch.cat( [ fp_ebc_out[key] - for key in [ - "feature_0", - "feature_1", - "feature_2", - "feature_3", - ] + for key in ["feature_0", "feature_1", "feature_2", "feature_3"] ], dim=1, ) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 08c519958..e5c5d5d7f 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -22,10 +22,7 @@ from torch._dynamo.testing import reduce_to_scalar_loss from torch._dynamo.utils import counters from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - EmbeddingTableConfig, -) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.fp_embeddingbag import ( FeatureProcessedEmbeddingBagCollectionSharder, @@ -34,13 +31,8 @@ from torchrec.distributed.model_parallel import DMPCollection from torchrec.distributed.sharding_plan import ( construct_module_sharding_plan, - 
row_wise, table_wise, ) -from torchrec.distributed.test_utils.multi_process import ( - MultiProcessContext, - MultiProcessTestBase, -) from torchrec.distributed.test_utils.test_model import ( ModelInput, TestNegSamplingModule, @@ -341,161 +333,6 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None: torch.testing.assert_close(pred_gpu.cpu(), pred) -def fp_ebc_rw_sharding_test_runner( - rank: int, - world_size: int, - tables: List[EmbeddingTableConfig], - weighted_tables: List[EmbeddingTableConfig], - data: List[Tuple[ModelInput, List[ModelInput]]], - backend: str = "nccl", - local_size: Optional[int] = None, -) -> None: - with MultiProcessContext(rank, world_size, backend, local_size) as ctx: - assert ctx.pg is not None - sharder = cast( - ModuleSharder[nn.Module], - FeatureProcessedEmbeddingBagCollectionSharder(), - ) - - class DummyWrapper(nn.Module): - def __init__(self, sparse_arch): - super().__init__() - self.m = sparse_arch - - def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]: - return self.m(model_input.idlist_features) - - max_feature_lengths = [10, 10, 12, 12] - sparse_arch = DummyWrapper( - create_module_and_freeze( - tables=tables, # pyre-ignore[6] - device=ctx.device, - use_fp_collection=False, - max_feature_lengths=max_feature_lengths, - ) - ) - - # compute_kernel = EmbeddingComputeKernel.FUSED.value - module_sharding_plan = construct_module_sharding_plan( - sparse_arch.m._fp_ebc, - per_param_sharding={ - "table_0": row_wise(), - "table_1": row_wise(), - "table_2": row_wise(), - "table_3": row_wise(), - }, - world_size=2, - device_type=ctx.device.type, - sharder=sharder, - ) - sharded_sparse_arch_pipeline = DistributedModelParallel( - module=copy.deepcopy(sparse_arch), - plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), - env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6] - sharders=[sharder], - device=ctx.device, - ) - sharded_sparse_arch_no_pipeline = DistributedModelParallel( - module=copy.deepcopy(sparse_arch), - plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), - env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6] - sharders=[sharder], - device=ctx.device, - ) - - batches = [] - for d in data: - batches.append(d[1][ctx.rank].to(ctx.device)) - dataloader = iter(batches) - - optimizer_no_pipeline = optim.SGD( - sharded_sparse_arch_no_pipeline.parameters(), lr=0.1 - ) - optimizer_pipeline = optim.SGD( - sharded_sparse_arch_pipeline.parameters(), lr=0.1 - ) - - pipeline = TrainPipelineSparseDist( - sharded_sparse_arch_pipeline, - optimizer_pipeline, - ctx.device, - ) - - for batch in batches[:-2]: - batch = batch.to(ctx.device) - optimizer_no_pipeline.zero_grad() - loss, pred = sharded_sparse_arch_no_pipeline(batch) - loss.backward() - optimizer_no_pipeline.step() - - pred_pipeline = pipeline.progress(dataloader) - torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu()) - - -class TrainPipelineGPUTest(MultiProcessTestBase): - def setUp(self, backend: str = "nccl") -> None: - super().setUp() - - self.pipeline_class = TrainPipelineSparseDist - num_features = 4 - num_weighted_features = 4 - self.tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 100, - embedding_dim=(i + 1) * 4, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_features) - ] - self.weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=(i + 1) * 100, - embedding_dim=(i + 1) * 4, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in 
range(num_weighted_features) - ] - - self.backend = backend - if torch.cuda.is_available(): - self.device = torch.device("cuda") - else: - self.device = torch.device("cpu") - - if self.backend == "nccl" and self.device == torch.device("cpu"): - self.skipTest("NCCL not supported on CPUs.") - - def _generate_data( - self, - num_batches: int = 5, - batch_size: int = 1, - max_feature_lengths: Optional[List[int]] = None, - ) -> List[Tuple[ModelInput, List[ModelInput]]]: - return [ - ModelInput.generate( - tables=self.tables, - weighted_tables=self.weighted_tables, - batch_size=batch_size, - world_size=2, - num_float_features=10, - max_feature_lengths=max_feature_lengths, - ) - for i in range(num_batches) - ] - - def test_fp_ebc_rw(self) -> None: - data = self._generate_data(max_feature_lengths=[10, 10, 12, 12]) - self._run_multi_process_test( - callable=fp_ebc_rw_sharding_test_runner, - world_size=2, - tables=self.tables, - weighted_tables=self.weighted_tables, - data=data, - ) - - class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase): # pyre-fixme[56]: Pyre was not able to infer the type of argument @unittest.skipIf( diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index ce1a30554..8bea1ff37 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -169,7 +169,6 @@ def _start_data_dist( # and this info was done in the _rewrite_model by tracing the # entire model to get the arg_info_list args, kwargs = forward.args.build_args_kwargs(batch) - args, kwargs = module.preprocess_input(args, kwargs) # Start input distribution. module_ctx = module.create_context() @@ -405,8 +404,6 @@ def _rewrite_model( # noqa C901 logger.info(f"Module '{node.target}' will be pipelined") child = sharded_modules[node.target] original_forwards.append(child.forward) - # Set pipelining flag on the child module - child.is_pipelined = True # pyre-ignore[8] Incompatible attribute type child.forward = pipelined_forward( node.target, diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 9e8c0cc83..5bac4e396 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -19,10 +19,7 @@ Generic, Iterator, List, - Mapping, Optional, - ParamSpec, - Sequence, Tuple, Type, TypeVar, @@ -82,8 +79,6 @@ class GenericMeta(type): ) from torchrec.streamable import Multistreamable -_T = TypeVar("_T") - def _tabulate( table: List[List[Union[str, int]]], headers: Optional[List[str]] = None @@ -1041,8 +1036,6 @@ def __init__( if qcomm_codecs_registry is None: qcomm_codecs_registry = {} self._qcomm_codecs_registry = qcomm_codecs_registry - # In pipelining, this flag is flipped in rewrite_model when the forward is replaced with the pipelined forward - self.is_pipelined = False @abc.abstractmethod def create_context(self) -> ShrdCtx: @@ -1145,19 +1138,6 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for key, _ in self.named_parameters(prefix): yield key - def preprocess_input( - self, - args: List[_T], - kwargs: Mapping[str, _T], - ) -> Tuple[List[_T], Mapping[str, _T]]: - """ - This function can be used to preprocess the input arguments prior to module forward call. - - For example, it is used in ShardedFeatureProcessorEmbeddingBagCollection to transform the input data - prior to the forward call. 
- """ - return args, kwargs - @property @abc.abstractmethod def unsharded_module_type(self) -> Type[nn.Module]: diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index f2dbefa91..b12660e97 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -26,10 +26,8 @@ from torch import nn from torch.autograd.profiler import record_function from torchrec import optim as trec_optim -from torchrec.distributed.embedding_types import ( - EmbeddingComputeKernel, - KeyedJaggedTensor, -) +from torchrec.distributed.embedding_types import EmbeddingComputeKernel + from torchrec.distributed.types import ( DataType, EmbeddingEvent, @@ -40,7 +38,6 @@ ShardMetadata, ) from torchrec.modules.embedding_configs import data_type_to_sparse_type -from torchrec.modules.feature_processor_ import FeatureProcessorsCollection from torchrec.types import CopyMixIn logger: logging.Logger = logging.getLogger(__name__) @@ -761,47 +758,3 @@ def _recalculate_torch_state_helper( _recalculate_torch_state_helper(child) _recalculate_torch_state_helper(module) - - -def modify_input_for_feature_processor( - features: KeyedJaggedTensor, - feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection], - is_collection: bool, -) -> None: - """ - This function applies the feature processor pre input dist. This way we - can support row wise based sharding mechanisms. - - This is an inplace modifcation of the input KJT. - """ - with torch.no_grad(): - if features.weights_or_none() is None: - # force creation of weights, this way the feature jagged tensor weights are tied to the original KJT - features._weights = torch.zeros_like(features.values(), dtype=torch.float32) - - if is_collection: - if hasattr(feature_processors, "pre_process_pipeline_input"): - feature_processors.pre_process_pipeline_input(features) # pyre-ignore[29] - else: - logging.info( - f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processors=}" - ) - else: - # per feature process - for feature in features.keys(): - if feature in feature_processors: # pyre-ignore[58] - feature_processor = feature_processors[feature] # pyre-ignore[29] - if hasattr(feature_processor, "pre_process_pipeline_input"): - feature_processor.pre_process_pipeline_input(features[feature]) - else: - logging.info( - f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processor=}" - ) - else: - features[feature].weights().copy_( - torch.ones( - features[feature].values().shape[0], - dtype=torch.float32, - device=features[feature].values().device, - ) - ) diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index f064ad5e3..707f5bd2b 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -14,7 +14,7 @@ import torch -from torch import distributed as dist, nn +from torch import nn from torch.nn.modules.module import _IncompatibleKeys from torchrec.pt2.checks import is_non_strict_exporting @@ -72,7 +72,6 @@ def __init__( torch.empty([max_feature_length], device=device), requires_grad=True, ) - self.pipelined = False self.reset_parameters() @@ -86,18 +85,15 @@ def forward( ) -> JaggedTensor: """ Args: - features (JaggedTensor): feature representation + features (JaggedTensor]): feature representation Returns: JaggedTensor: same as input features with `weights` field being populated. 
""" - if self.pipelined: - # position is embedded as weights - seq = features.weights().clone().to(torch.int64) - else: - seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + + seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) weighted_features = JaggedTensor( values=features.values(), lengths=features.lengths(), @@ -106,20 +102,6 @@ def forward( ) return weighted_features - def pre_process_pipeline_input(self, features: JaggedTensor) -> None: - """ - Args: - features (JaggedTensor]): feature representation - - Returns: - torch.Tensor: position weights - """ - self.pipelined = True - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) - features.weights().copy_(cat_seq.to(torch.float32)) - class FeatureProcessorsCollection(nn.Module): """ @@ -187,7 +169,7 @@ def __init__( for length in self.max_feature_lengths.values(): if length <= 0: raise - self.pipelined = False # if pipelined, input dist has performed part of input feature processing + self.position_weights: nn.ParameterDict = nn.ParameterDict() # needed since nn.ParameterDict isn't torchscriptable (get_items) self.position_weights_dict: Dict[str, nn.Parameter] = {} @@ -209,6 +191,7 @@ def reset_parameters(self) -> None: self.position_weights_dict[key] = self.position_weights[key] def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + # TODO unflattener doesnt work well with aten.to at submodule boundaries if is_non_strict_exporting(): offsets = features.offsets() if offsets.dtype == torch.int64: @@ -220,12 +203,9 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: features.offsets().long(), torch.numel(features.values()) ) else: - if self.pipelined: - cat_seq = features.weights().clone().to(torch.int64) - else: - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) return KeyedJaggedTensor( keys=features.keys(), @@ -265,10 +245,3 @@ def load_state_dict( for k, param in self.position_weights.items(): self.position_weights_dict[k] = param return result - - def pre_process_pipeline_input(self, features: KeyedJaggedTensor) -> None: - self.pipelined = True - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) - features.weights().copy_(cat_seq.to(torch.float32))