From c9b21704208cd1f579872219b2fcd5ae21ff59be Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Thu, 1 May 2025 01:21:04 -0700 Subject: [PATCH] provide basic integration of trec 2d emb to training pipeline (#2929) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2929 * add `dmp_collection_sync_interval_batches` as a config value to SDD pipeline (and semi sync) * default to 1 (every batch). * disable using `None` * disabled if DMPC not used Differential Revision: D70988786 --- torchrec/distributed/embedding_types.py | 7 +- torchrec/distributed/quant_state.py | 5 +- .../tests/test_train_pipelines.py | 108 +++++++++++++++++- .../train_pipeline/train_pipelines.py | 75 ++++++++++++ torchrec/distributed/train_pipeline/utils.py | 9 +- torchrec/distributed/types.py | 8 ++ torchrec/distributed/utils.py | 11 ++ 7 files changed, 212 insertions(+), 11 deletions(-) diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index e8da4a6da..40d9f2308 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -15,6 +15,7 @@ Any, Dict, Generic, + Iterable, Iterator, List, Optional, @@ -329,7 +330,7 @@ def shardings(self) -> Dict[str, FeatureShardingMixIn]: Out = TypeVar("Out") -CompIn = TypeVar("CompIn") +CompIn = TypeVar("CompIn", KJTList, ListOfKJTList, KeyedJaggedTensor) DistOut = TypeVar("DistOut") ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) @@ -358,14 +359,14 @@ def __init__( def prefetch( self, - dist_input: KJTList, + dist_input: CompIn, forward_stream: Optional[Union[torch.cuda.Stream, torch.mtia.Stream]] = None, ctx: Optional[ShrdCtx] = None, ) -> None: """ Prefetch input features for each lookup module. """ - + assert isinstance(dist_input, Iterable) for feature, emb_lookup in zip(dist_input, self._lookups): while isinstance(emb_lookup, DistributedDataParallel): emb_lookup = emb_lookup.module diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 1de388e1b..c9af2afbe 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -26,6 +26,9 @@ from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo from torchrec.distributed.embedding_types import ( GroupedEmbeddingConfig, + KeyedJaggedTensor, + KJTList, + ListOfKJTList, ShardedEmbeddingModule, ) from torchrec.distributed.types import ParameterSharding, ShardingType @@ -34,7 +37,7 @@ from torchrec.tensor_types import UInt2Tensor, UInt4Tensor Out = TypeVar("Out") -CompIn = TypeVar("CompIn") +CompIn = TypeVar("CompIn", KJTList, ListOfKJTList, KeyedJaggedTensor) DistOut = TypeVar("DistOut") ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index a563281ca..3b2b94196 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -10,11 +10,11 @@ import copy import unittest -from contextlib import ExitStack +from contextlib import contextmanager, ExitStack from dataclasses import dataclass from functools import partial -from typing import cast, Dict, List, Optional, Tuple, Type, Union -from unittest.mock import MagicMock +from typing import cast, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from unittest.mock import MagicMock, patch import torch from hypothesis import assume, given, settings, strategies as 
st, Verbosity @@ -29,6 +29,7 @@ FeatureProcessedEmbeddingBagCollectionSharder, ShardedFeatureProcessedEmbeddingBagCollection, ) +from torchrec.distributed.model_parallel import DMPCollection from torchrec.distributed.sharding_plan import ( construct_module_sharding_plan, table_wise, @@ -86,6 +87,9 @@ from torchrec.streamable import Pipelineable +T = TypeVar("T") + + @dataclass class ModelInputSimple(Pipelineable): float_features: torch.Tensor @@ -692,6 +696,104 @@ def custom_model_fwd( self.assertEqual(pred_pipeline.size(0), 64) +class TrainPipelineSparseDist2DShardingTest(unittest.TestCase): + @contextmanager + def _mocked_pipeline(self, obj: T) -> Generator[T, None, None]: + disabled_methods = [ + "fill_pipeline", + "_wait_for_batch", + "enqueue_batch", + "dequeue_batch", + "start_sparse_data_dist", + "wait_sparse_data_dist", + ] + + with ExitStack() as stack: + for method in disabled_methods: + stack.enter_context( + patch.object(obj.__class__, method, return_value=None) + ) + yield obj + + def test_dmp_collection_sync(self) -> None: + dmp = MagicMock(spec=DMPCollection) + dmp.training = True + dmp.return_value = ( + torch.tensor(0.1, requires_grad=True), + torch.tensor(2), + ) # loss, output + + optimizer = MagicMock(spec=torch.optim.Optimizer) + data_iter = MagicMock() + mock_data: MagicMock = MagicMock(spec=Pipelineable) + + def _add_context(pipeline: TrainPipelineSparseDist) -> None: # pyre-ignore + context = TrainPipelineContext() + context.index = 10 + for _ in range(3): + pipeline.batches.append(mock_data) + pipeline.contexts.append(context) + + # disable + pipeline = TrainPipelineSparseDist( + dmp, + optimizer, + device=torch.device("cpu"), + dmp_collection_sync_interval_batches=None, + ) + _add_context(pipeline) + with self._mocked_pipeline(pipeline): + pipeline.progress(data_iter) + + dmp.sync.assert_not_called() + + # enable + dmp.reset_mock() + pipeline_with_dmp_sync = TrainPipelineSparseDist( + dmp, + optimizer, + device=torch.device("cpu"), + dmp_collection_sync_interval_batches=10, + ) + _add_context(pipeline_with_dmp_sync) + with self._mocked_pipeline(pipeline_with_dmp_sync): + pipeline_with_dmp_sync.progress(data_iter) + + dmp.assert_called_once() + dmp.sync.assert_called_once() + + def test_sync_disabled_if_dmp_collection_is_not_used(self) -> None: + dmp = MagicMock(spec=DistributedModelParallel) + dmp.training = True + dmp.return_value = ( + torch.tensor(0.1, requires_grad=True), + torch.tensor(2), + ) # loss, output + + optimizer = MagicMock(spec=torch.optim.Optimizer) + data_iter = MagicMock() + mock_data: MagicMock = MagicMock(spec=Pipelineable) + + # set interval but pass in raw DMP + # interval will be ignored + pipeline = TrainPipelineSparseDist( + dmp, + optimizer, + device=torch.device("cpu"), + dmp_collection_sync_interval_batches=10, + ) + context = TrainPipelineContext() + context.index = 10 + for _ in range(3): + pipeline.batches.append(mock_data) + pipeline.contexts.append(context) + with self._mocked_pipeline(pipeline): + # no exception + pipeline.progress(data_iter) + + dmp.assert_called_once() + + class TrainPipelineAttachDetachTest(TrainPipelineSparseDistTestBase): @unittest.skipIf( not torch.cuda.is_available(), diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 4685fae9c..b35f68568 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -76,6 +76,17 @@ 
torch.ops.import_module("fbgemm_gpu.sparse_ops") +# Note: doesn't make much sense but better than throwing. +# Somehow some users would mess up their dependency when using torch package, +# and we cannot fix their problem sorry +has_2d_support = True +try: + from torchrec.distributed.model_parallel import DMPCollection +except ImportError: + logger.warning("DMPCollection is not available. 2D sharding is not supported.") + has_2d_support = False + + class ModelDetachedException(Exception): pass @@ -85,6 +96,39 @@ class TrainPipeline(abc.ABC, Generic[In, Out]): def progress(self, dataloader_iter: Iterator[In]) -> Out: pass + def sync_embeddings( + self, + model: torch.nn.Module, + interval_batches: Optional[int], + context: Optional[TrainPipelineContext] = None, + ) -> None: + """ + Sync the embedding weights and fused optimizer states across replicas. + Only enabled if DMPCollection is used to shard the model. + Otherwise this is a no op. + """ + if ( + not has_2d_support + or not isinstance(model, DMPCollection) + or interval_batches is None + ): + return + + if not context: + logger.warning( + f"{self.__class__.__name__} does not support context (not expected). " + "Embedding weight sync is disabled." + ) + return + + index = context.index + assert ( + index is not None + ), f"{self.__class__.__name__} context does not provide number of batches: {context=}" + if index % interval_batches == 0: + with record_function("## dmp_collection_sync ##"): + model.sync() + @dataclass class TorchCompileConfig: @@ -343,6 +387,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]): execute_all_batches (bool): executes remaining batches in pipeline after exhausting dataloader iterator. apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules. + dmp_collection_sync_interval_batches (Optional[int]): + (applicable to 2D sharding only) + if set and DMP collection is enabled for 2D sharding, + sync DMPs every N batches (default to 1, i.e. every batch, None to disable) """ # The PipelinedForward class that is used in _rewrite_model @@ -361,6 +409,7 @@ def __init__( custom_model_fwd: Optional[ Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, + dmp_collection_sync_interval_batches: Optional[int] = 1, ) -> None: self._model = model self._optimizer = optimizer @@ -416,6 +465,15 @@ def __init__( self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = ( custom_model_fwd if custom_model_fwd else model ) + self._pipelined_forward_type = PipelinedForward + self._dmp_collection_sync_interval_batches = ( + dmp_collection_sync_interval_batches + ) + if self._dmp_collection_sync_interval_batches is not None: + logger.info( + f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every " + f"{self._dmp_collection_sync_interval_batches} batches" + ) # DEPRECATED FIELDS self._batch_i: Optional[In] = None @@ -610,6 +668,12 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # backward self._backward(losses) + self.sync_embeddings( + self._model, + self._dmp_collection_sync_interval_batches, + self.contexts[0], + ) + # update with record_function("## optimizer ##"): self._optimizer.step() @@ -823,6 +887,10 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]): start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX. stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync" training. 
If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync) + dmp_collection_sync_interval_batches (Optional[int]): + (applicable to 2D sharding only) + if set and DMP collection is enabled for 2D sharding, + sync DMPs every N batches (default to 1, i.e. every batch, None to disable) """ # The PipelinedForward class that is used in _rewrite_model @@ -842,6 +910,7 @@ def __init__( Callable[[Optional[In]], Tuple[torch.Tensor, Out]] ] = None, strict: bool = False, + dmp_collection_sync_interval_batches: Optional[int] = 1, ) -> None: super().__init__( model=model, @@ -852,6 +921,7 @@ def __init__( context_type=EmbeddingTrainPipelineContext, pipeline_postproc=pipeline_postproc, custom_model_fwd=custom_model_fwd, + dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches, ) self._start_batch = start_batch self._stash_gradients = stash_gradients @@ -972,6 +1042,11 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # pyre-ignore [6] self.embedding_backward(context) + self.sync_embeddings( + self._model, + self._dmp_collection_sync_interval_batches, + context, + ) del context # context is no longer needed, deleting to free up memory with record_function(f"## optimizer {iteration - 1} ##"): diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index f75a25af4..9554d4e63 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -1731,7 +1731,7 @@ def _prefetch_embeddings( context: PrefetchTrainPipelineContext, pipelined_modules: List[ShardedModule], device: torch.device, - stream_context: torch.Stream, + stream_context: Callable[[Optional[torch.Stream]], torch.cuda.StreamContext], data_dist_stream: Optional[torch.Stream], default_stream: Optional[torch.Stream], ) -> Dict[str, KJTList]: @@ -1739,7 +1739,6 @@ def _prefetch_embeddings( for sharded_module in pipelined_modules: forward = sharded_module.forward assert isinstance(forward, PrefetchPipelinedForward) - assert forward._name in context.input_dist_tensors_requests request = context.input_dist_tensors_requests.pop(forward._name) assert isinstance(request, Awaitable) @@ -1762,10 +1761,12 @@ def _prefetch_embeddings( data, (torch.Tensor, Multistreamable) ), f"{type(data)} must implement Multistreamable interface" data.record_stream(cur_stream) - data.record_stream(default_stream) + if default_stream: + data.record_stream(default_stream) module_context.record_stream(cur_stream) - module_context.record_stream(default_stream) + if default_stream: + module_context.record_stream(default_stream) sharded_module.prefetch( ctx=module_context, diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index ac7260d25..f3a0bcb60 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -1003,6 +1003,14 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut: def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]: pass + def prefetch( + self, + dist_input: CompIn, + forward_stream: Optional[torch.Stream] = None, + ctx: Optional[ShrdCtx] = None, + ) -> None: + return None + def compute_and_output_dist( self, ctx: ShrdCtx, input: CompIn ) -> LazyAwaitable[Out]: diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index 04a5afe0a..6adca02c0 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -45,6 +45,17 @@ """ +def get_class_name(obj: object) -> str: + if obj is None: + return "None" + 
return obj.__class__.__name__ + + +def assert_instance(obj: object, t: Type[_T]) -> _T: + assert isinstance(obj, t), f"Got {get_class_name(obj)}" + return obj + + def none_throws(optional: Optional[_T], message: str = "Unexpected `None`") -> _T: """Convert an optional to its value. Raises an `AssertionError` if the value is `None`"""
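
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new `dmp_collection_sync_interval_batches` argument added by this change would be passed to `TrainPipelineSparseDist`. The `model` and `optimizer` names are placeholders assumed to be defined elsewhere; `model` is assumed to already be wrapped in `DMPCollection` for 2D sharding.

    import torch
    from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist

    # Assumptions: `model` is a DMPCollection-wrapped module (2D sharding) and
    # `optimizer` is a torch.optim.Optimizer over the dense parameters.
    pipeline = TrainPipelineSparseDist(
        model,
        optimizer,
        device=torch.device("cuda"),
        # Sync embedding weights and fused optimizer states across replicas
        # every 10 batches. Defaults to 1 (every batch); None disables the sync.
        dmp_collection_sync_interval_batches=10,
    )

    # If `model` is a plain DistributedModelParallel rather than a DMPCollection,
    # the interval is ignored and the periodic sync is a no-op.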