From bf1a63e7a92137371cf13bd389d78c87bae4607f Mon Sep 17 00:00:00 2001 From: Jingchang Zhang Date: Tue, 2 Sep 2025 19:38:03 -0700 Subject: [PATCH] Add an option to move embedding lookup after sparse data dist in FusedSDD Summary: This diff adds an option that allows the embedding lookup to be triggered after the sparse data dist. This can potentially improve performance when the CPU is blocked by the sparse data dist kernel launch and cannot launch the forward kernel earlier. {F1981658737,width=300} Differential Revision: D81494775 --- .../distributed/train_pipeline/train_pipelines.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 53e7522f3..0df35ac26 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -964,6 +964,7 @@ def __init__( ] = None, strict: bool = False, emb_lookup_stream: str = "data_dist", # new, current, data_dist (default) + embedding_lookup_after_data_dist: bool = False, ) -> None: super().__init__( model=model, @@ -975,6 +976,8 @@ def __init__( pipeline_postproc=pipeline_postproc, custom_model_fwd=custom_model_fwd, ) + self._embedding_lookup_after_data_dist = embedding_lookup_after_data_dist + if emb_lookup_stream == "new": self._emb_lookup_stream: Optional[torch.Stream] = ( (torch.get_device_module(device).Stream()) @@ -1046,8 +1049,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: self._set_module_context(self.contexts[0]) # start embedding_lookup so it can overlap with previous optimizer - # pyre-ignore [6] - self.start_embedding_lookup(self.batches[0], self.contexts[0]) + if not self._embedding_lookup_after_data_dist: + # pyre-ignore [6] + self.start_embedding_lookup(self.batches[0], self.contexts[0]) if self._model.training: with record_function("## zero_grad ##"): @@ -1064,6 +1068,10 @@ def progress(self, dataloader_iter:
Iterator[In]) -> Out: # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here self.enqueue_batch(dataloader_iter) + if self._embedding_lookup_after_data_dist: + # pyre-ignore [6] + self.start_embedding_lookup(self.batches[0], self.contexts[0]) + # forward with record_function(f"## forward {self.contexts[0].index} ##"): losses, output = self._model_fwd(self.batches[0])