diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 1bfe95f36..b8be13994 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -791,19 +792,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)
 
-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-                # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)
 
     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
@@ -1324,22 +1315,15 @@ def start_sparse_data_dist(
             return
 
         # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
     def start_embedding_lookup(
         self,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 39a04e48f..2a561f80a 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -7,6 +7,8 @@
 # pyre-strict
 
 import abc
+
+import contextlib
 import copy
 import itertools
 import logging
@@ -21,6 +23,7 @@
     Callable,
     cast,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -1834,6 +1837,28 @@ def _prefetch_embeddings(
     return data_per_sharded_module
 
 
+@contextlib.contextmanager
+def use_context_for_postprocs(
+    pipelined_postprocs: List[PipelinedPostproc],
+    next_batch_context: TrainPipelineContext,
+) -> Generator[None, None, None]:
+    """
+    Temporarily set pipelined postproc context for next iter to populate cache.
+    """
+    # Save original context for model fwd
+    original_contexts = [p.get_context() for p in pipelined_postprocs]
+
+    # Temporarily set context for next iter to populate cache
+    for postproc_mod in pipelined_postprocs:
+        postproc_mod.set_context(next_batch_context)
+
+    yield
+
+    # Restore context for model fwd
+    for module, context in zip(pipelined_postprocs, original_contexts):
+        module.set_context(context)
+
+
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.
@@ -1845,6 +1870,7 @@ class SparseDataDistUtil(Generic[In]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+        pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.
 
     Example::
         sdd = SparseDataDistUtil(