From a68abff505b15224faa8c47386a058ea044826ca Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Mon, 3 Mar 2025 08:40:06 -0800
Subject: [PATCH] Support MCH for semi-sync (assuming no eviction) (#2753)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2753

ZCH modules return a tuple of awaitables for embeddings and remapped KJTs.
Update semi-sync training code to account for this.

Reviewed By: dstaay-fb

Differential Revision: D69861054
---
 .../distributed/composable/tests/test_ddp.py  |   2 +
 .../distributed/composable/tests/test_fsdp.py |   2 +
 torchrec/distributed/test_utils/test_model.py | 217 +++++++++++++++++-
 .../tests/test_train_pipelines.py             |  10 +-
 .../tests/test_train_pipelines_base.py        |  14 +-
 torchrec/distributed/train_pipeline/utils.py  |  48 +++-
 6 files changed, 273 insertions(+), 20 deletions(-)

diff --git a/torchrec/distributed/composable/tests/test_ddp.py b/torchrec/distributed/composable/tests/test_ddp.py
index b6291afc9..60472bef3 100644
--- a/torchrec/distributed/composable/tests/test_ddp.py
+++ b/torchrec/distributed/composable/tests/test_ddp.py
@@ -105,11 +105,13 @@ def _run(cls, rank: int, world_size: int, path: str) -> None:
                 weighted_tables=weighted_tables,
                 dense_device=ctx.device,
             )
+            # pyre-ignore
             m.sparse.ebc = trec_shard(
                 module=m.sparse.ebc,
                 device=ctx.device,
                 plan=column_wise(ranks=list(range(world_size))),
             )
+            # pyre-ignore
             m.sparse.weighted_ebc = trec_shard(
                 module=m.sparse.weighted_ebc,
                 device=ctx.device,
diff --git a/torchrec/distributed/composable/tests/test_fsdp.py b/torchrec/distributed/composable/tests/test_fsdp.py
index aae30dde3..538b5382f 100644
--- a/torchrec/distributed/composable/tests/test_fsdp.py
+++ b/torchrec/distributed/composable/tests/test_fsdp.py
@@ -83,11 +83,13 @@ def _run(  # noqa
                 m.sparse.parameters(),
                 {"lr": 0.01},
             )
+            # pyre-ignore
            m.sparse.ebc = trec_shard(
                 module=m.sparse.ebc,
                 device=ctx.device,
                 plan=row_wise(),
             )
+            # pyre-ignore
            m.sparse.weighted_ebc = trec_shard(
                 module=m.sparse.weighted_ebc,
                 device=ctx.device,
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index db839f6ff..bb902ce02 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -26,7 +26,18 @@
 )
 from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder
 from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
-from torchrec.distributed.types import QuantizedCommCodecs
+from torchrec.distributed.mc_embedding_modules import (
+    BaseManagedCollisionEmbeddingCollectionSharder,
+)
+from torchrec.distributed.mc_embeddingbag import (
+    ShardedManagedCollisionEmbeddingBagCollection,
+)
+from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
+from torchrec.distributed.types import (
+    ParameterSharding,
+    QuantizedCommCodecs,
+    ShardingEnv,
+)
 from torchrec.distributed.utils import CopyableMixin
 from torchrec.modules.activation import SwishLayerNorm
 from torchrec.modules.embedding_configs import (
@@ -39,6 +50,12 @@
 from torchrec.modules.feature_processor import PositionWeightedProcessor
 from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Pipelineable
@@ -1351,6 +1368,7 @@ def __init__(
         feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
         over_arch_clazz: Type[nn.Module] = TestOverArch,
         postproc_module: Optional[nn.Module] = None,
+        zch: bool = False,
     ) -> None:
         super().__init__(
             tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1362,12 +1380,20 @@ def __init__(
         if weighted_tables is None:
             weighted_tables = []
         self.dense = TestDenseArch(num_float_features, dense_device)
-        self.sparse = TestSparseArch(
-            tables,
-            weighted_tables,
-            sparse_device,
-            max_feature_lengths,
-        )
+        if zch:
+            self.sparse: nn.Module = TestSparseArchZCH(
+                tables,
+                weighted_tables,
+                torch.device("meta"),
+                return_remapped=True,
+            )
+        else:
+            self.sparse = TestSparseArch(
+                tables,
+                weighted_tables,
+                sparse_device,
+                max_feature_lengths,
+            )
 
         embedding_names = (
             list(embedding_groups.values())[0] if embedding_groups else None
@@ -1687,6 +1713,64 @@ def compute_kernels(
         return [self._kernel_type]
 
 
+class TestMCSharder(ManagedCollisionCollectionSharder):
+    def __init__(
+        self,
+        sharding_type: str,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        self._sharding_type = sharding_type
+        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [self._sharding_type]
+
+
+class TestEBCSharderMCH(
+    BaseManagedCollisionEmbeddingCollectionSharder[
+        ManagedCollisionEmbeddingBagCollection
+    ]
+):
+    def __init__(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        fused_params: Optional[Dict[str, Any]] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__(
+            TestEBCSharder(
+                sharding_type, kernel_type, fused_params, qcomm_codecs_registry
+            ),
+            TestMCSharder(sharding_type, qcomm_codecs_registry),
+            qcomm_codecs_registry=qcomm_codecs_registry,
+        )
+
+    @property
+    def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+
+    def shard(
+        self,
+        module: ManagedCollisionEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedManagedCollisionEmbeddingBagCollection:
+        if device is None:
+            device = torch.device("cuda")
+        return ShardedManagedCollisionEmbeddingBagCollection(
+            module,
+            params,
+            # pyre-ignore [6]
+            ebc_sharder=self._e_sharder,
+            mc_sharder=self._mc_sharder,
+            env=env,
+            device=device,
+        )
+
+
 class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder):
     def __init__(
         self,
@@ -2188,3 +2272,122 @@ def forward(self, input: ModelInput) -> ModelInput:
         modified_input = copy.deepcopy(input)
         modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
         return modified_input
+
+
+class TestSparseArchZCH(nn.Module):
+    """
+    Basic nn.Module for testing MCH EmbeddingBagCollection
+
+    Args:
+        tables: list of unweighted EmbeddingBagConfigs
+        weighted_tables: list of weighted EmbeddingBagConfigs
+        device: device to construct the managed collision and embedding modules on
+        return_remapped: whether to also return the remapped input features
+
+    Call Args:
+        features: KeyedJaggedTensor of unweighted id list features
+        weighted_features: optional KeyedJaggedTensor of weighted id score list features
+        batch_size: optional batch size
+
+    Returns:
+        KeyedTensor
+
+    Example::
+
+        TestSparseArchZCH()
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        return_remapped: bool = False,
+    ) -> None:
+        super().__init__()
+        self._return_remapped = return_remapped
+
+        mc_modules = {}
+        for table in tables:
+            mc_modules[table.name] = MCHManagedCollisionModule(
+                zch_size=table.num_embeddings,
+                input_hash_size=4000,
+                device=device,
+                # TODO: If the eviction interval is set to
+                # a low number (e.g. 2), the semi-sync pipeline test will
+                # fail with an in-place modification error during
+                # loss.backward(). This is because during semi-sync training,
+                # we run the embedding module forward after the autograd graph
+                # is constructed, but if MCH eviction happens, the
+                # variables used in autograd will have been modified.
+                eviction_interval=1000,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+
+        self.ebc: ManagedCollisionEmbeddingBagCollection = (
+            ManagedCollisionEmbeddingBagCollection(
+                EmbeddingBagCollection(
+                    tables=tables,
+                    device=device,
+                ),
+                ManagedCollisionCollection(
+                    managed_collision_modules=mc_modules,
+                    embedding_configs=tables,
+                ),
+                return_remapped_features=self._return_remapped,
+            )
+        )
+
+        self.weighted_ebc: Optional[ManagedCollisionEmbeddingBagCollection] = None
+        if weighted_tables:
+            weighted_mc_modules = {}
+            for table in weighted_tables:
+                weighted_mc_modules[table.name] = MCHManagedCollisionModule(
+                    zch_size=table.num_embeddings,
+                    input_hash_size=4000,
+                    device=device,
+                    # TODO: Support MCH evictions during semi-sync
+                    eviction_interval=1000,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                )
+            self.weighted_ebc: ManagedCollisionEmbeddingBagCollection = (
+                ManagedCollisionEmbeddingBagCollection(
+                    EmbeddingBagCollection(
+                        tables=weighted_tables,
+                        device=device,
+                        is_weighted=True,
+                    ),
+                    ManagedCollisionCollection(
+                        managed_collision_modules=weighted_mc_modules,
+                        embedding_configs=weighted_tables,
+                    ),
+                    return_remapped_features=self._return_remapped,
+                )
+            )
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+        weighted_features: Optional[KeyedJaggedTensor] = None,
+        batch_size: Optional[int] = None,
+    ) -> KeyedTensor:
+        """
+        Runs forward on the MC EBC and, optionally, the weighted MC EBC,
+        then merges the results into one KeyedTensor
+
+        Args:
+            features: KeyedJaggedTensor of unweighted id list features
+            weighted_features: optional KeyedJaggedTensor of weighted id score list features
+            batch_size: optional batch size
+        Returns:
+            KeyedTensor
+        """
+        ebc, _ = self.ebc(features)
+        ebc = _post_ebc_test_wrap_function(ebc)
+        w_ebc, _ = (
+            self.weighted_ebc(weighted_features)
+            if self.weighted_ebc is not None and weighted_features is not None
+            else (None, None)
+        )
+        result = _post_sparsenn_forward(ebc, None, w_ebc, batch_size)
+        return result
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index dc23593a0..03aa3ea96 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -17,7 +17,7 @@
 from unittest.mock import MagicMock
 
 import torch
-from hypothesis import given, settings, strategies as st, Verbosity
+from hypothesis import assume, given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
@@ -1531,7 +1531,7 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
     )
-    @settings(max_examples=4, deadline=None)
+    @settings(max_examples=8, deadline=None)
     # pyre-ignore[56]
     @given(
         start_batch=st.sampled_from([0, 6]),
@@ -1547,6 +1547,7 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
                 EmbeddingComputeKernel.FUSED.value,
             ]
         ),
+        zch=st.booleans(),
     )
     def test_equal_to_non_pipelined(
         self,
@@ -1554,10 +1555,13 @@ def test_equal_to_non_pipelined(
         stash_gradients: bool,
         sharding_type: str,
         kernel_type: str,
+        zch: bool,
    ) -> None:
        """
        Checks that pipelined training is equivalent to non-pipelined training.
        """
+        # ZCH only supports row-wise currently
+        assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value))
        torch.autograd.set_detect_anomaly(True)
        data = self._generate_data(
            num_batches=12,
@@ -1572,7 +1576,7 @@ def test_equal_to_non_pipelined(
             **fused_params,
         }
 
-        model = self._setup_model()
+        model = self._setup_model(zch=zch)
         sharded_model, optim = self._generate_sharded_model_and_optimizer(
             model, sharding_type, kernel_type, fused_params
         )
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index a5ed6e7b5..56e6ac636 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -21,6 +21,7 @@
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
+    TestEBCSharderMCH,
     TestSparseNN,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
@@ -96,6 +97,7 @@ def _setup_model(
         model_type: Type[nn.Module] = TestSparseNN,
         enable_fsdp: bool = False,
         postproc_module: Optional[nn.Module] = None,
+        zch: bool = False,
     ) -> nn.Module:
         unsharded_model = model_type(
             tables=self.tables,
@@ -103,6 +105,7 @@ def _setup_model(
             dense_device=self.device,
             sparse_device=torch.device("meta"),
             postproc_module=postproc_module,
+            zch=zch,
         )
         if enable_fsdp:
             unsharded_model.over.dhn_arch.linear0 = FSDP(
@@ -135,6 +138,11 @@ def _generate_sharded_model_and_optimizer(
             kernel_type=kernel_type,
             fused_params=fused_params,
         )
+        mc_sharder = TestEBCSharderMCH(
+            sharding_type=sharding_type,
+            kernel_type=kernel_type,
+            fused_params=fused_params,
+        )
         sharded_model = DistributedModelParallel(
             module=copy.deepcopy(model),
             env=ShardingEnv.from_process_group(self.pg),
@@ -144,7 +152,11 @@ def _generate_sharded_model_and_optimizer(
                 cast(
                     ModuleSharder[nn.Module],
                     sharder,
-                )
+                ),
+                cast(
+                    ModuleSharder[nn.Module],
+                    mc_sharder,
+                ),
             ],
         )
         # default fused optimizer is SGD w/ lr=0.1; we need to drop params
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 8f0c9d569..3c48bd6fd 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -138,9 +138,16 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 
 @dataclass
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
-    embedding_a2a_requests: Dict[str, LazyAwaitable[Multistreamable]] = field(
-        default_factory=dict
-    )
+    embedding_a2a_requests: Dict[
+        str,
+        Union[
+            LazyAwaitable[Multistreamable],
+            # ManagedCollisionEC/EBC returns tuple of awaitables
+            Tuple[
+                LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]
+            ],
+        ],
+    ] = field(default_factory=dict)
     embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
     embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
     detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
@@ -493,6 +500,8 @@ def load_state_dict(
 
 TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext)
 
+EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor]
+
 
 class BaseForward(Generic[TForwardContext]):
     def __init__(
@@ -559,8 +568,18 @@ def __call__(self, *input, **kwargs) -> Awaitable:
 
 
 class EmbeddingPipelinedForward(BaseForward[EmbeddingTrainPipelineContext]):
-    # pyre-ignore [2, 24]
-    def __call__(self, *input, **kwargs) -> Awaitable:
+    def __call__(
+        self,
+        # pyre-ignore
+        *input,
+        # pyre-ignore
+        **kwargs,
+    ) -> Union[
+        Awaitable[EmbeddingModuleRetType],
+        Tuple[
+            Awaitable[EmbeddingModuleRetType], Awaitable[Optional[KeyedJaggedTensor]]
+        ],
+    ]:
         assert (
             self._name in self._context.embedding_a2a_requests
         ), "Invalid EmbeddingPipelinedForward usage, please do not directly call model.forward()"
@@ -574,7 +593,15 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         )
         ctx.record_stream(cur_stream)
         awaitable = self._context.embedding_a2a_requests.pop(self._name)
-        embeddings = awaitable.wait()  # trigger awaitable manually for type checking
+        remapped_kjts: Optional[KeyedJaggedTensor] = None
+        if isinstance(awaitable, Iterable):
+            embeddings = awaitable[0].wait()
+            remapped_kjts = awaitable[1].wait()
+        else:
+            assert isinstance(awaitable, Awaitable)
+            embeddings = (
+                awaitable.wait()
+            )  # trigger awaitable manually for type checking
         tensors = []
         detached_tensors = []
         if isinstance(embeddings, Dict):
@@ -608,7 +635,10 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             self._context.embedding_features.append([list(embeddings.keys())])
         self._context.detached_embedding_tensors.append(detached_tensors)
 
-        return LazyNoWait(embeddings)
+        if isinstance(awaitable, Iterable):
+            return (LazyNoWait(embeddings), LazyNoWait(remapped_kjts))
+        else:
+            return LazyNoWait(embeddings)
 
 
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
@@ -821,8 +851,8 @@ def _start_embedding_lookup(
     if target_stream is not None:
         kjt.record_stream(target_stream)
         module_context.record_stream(target_stream)
-    a2a_awaitable = module.compute_and_output_dist(module_context, kjt)
-    context.embedding_a2a_requests[module.forward.name] = a2a_awaitable
+    output_dist_out = module.compute_and_output_dist(module_context, kjt)
+    context.embedding_a2a_requests[module.forward.name] = output_dist_out
 
 
 def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
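
Note: the sketch below is illustrative only and not part of the diff. It shows, under the assumption that a sharded MC-EBC/EC's compute_and_output_dist() hands back a pair of awaitables (embeddings first, remapped KJT second) while plain EBC/EC modules still hand back a single awaitable, how consumer code can normalize both cases. The helper name wait_embedding_output is hypothetical; Awaitable and KeyedJaggedTensor are the existing torchrec types.

    from typing import Any, Optional, Tuple, Union

    from torchrec.distributed.types import Awaitable
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


    def wait_embedding_output(
        out: Union[
            Awaitable[Any],
            Tuple[Awaitable[Any], Awaitable[Optional[KeyedJaggedTensor]]],
        ],
    ) -> Tuple[Any, Optional[KeyedJaggedTensor]]:
        # MC-EBC/EC case: a pair of awaitables -> (embeddings, remapped KJT).
        if isinstance(out, tuple):
            embeddings_awaitable, remapped_awaitable = out
            return embeddings_awaitable.wait(), remapped_awaitable.wait()
        # Plain EBC/EC case: a single awaitable of embeddings, no remapped KJT.
        return out.wait(), None

This mirrors the isinstance(awaitable, Iterable) branching added to EmbeddingPipelinedForward above, keyed off tuple instead of Iterable for clarity.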