From 91a16d507cf1cad044082cfd768e1f827aa94be1 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 16 Sep 2025 16:53:25 -0700
Subject: [PATCH] refactor PostProc tracing and debug message

Summary:
# context
* There are quite a few limitations in the postproc support for TorchRec's train pipeline
* add better warning messages for debugging

## symptoms
* unable to run input_dist in "-1" batch with the `SparseDistTrainPipeline`, AKA, SDD (Sparse Data Dist) pipeline
* warning in log: `Module '{node.target}' will NOT be pipelined, due to input modifications`

## typical issues
* root cause: input KJT is modified or passed through some module/function that potentially modifies the KJT
* pipeline_postproc is not enabled
* check the error message for `fx node {child_node.name, child_node.op, child_node.target} can't be handled correctly for postproc module`
* postproc module has trainable weights (sorry, we don't support this)
* a postproc function modifies the input KJT
* two postproc modules depend on a certain execution order

## workaround
* make the postproc function an nn.Module
* put order-dependent functions/modules under the same nn.Module to preserve the order.

Differential Revision: D82591429
---
 .../distributed/train_pipeline/tracing.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py
index c4a95756c..f107f784c 100644
--- a/torchrec/distributed/train_pipeline/tracing.py
+++ b/torchrec/distributed/train_pipeline/tracing.py
@@ -302,7 +302,11 @@ def _handle_module(
 
         if not self._pipeline_postproc:
             logger.warning(
-                f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`"
+                f"Found module {postproc_module} that potentially modifies input KJT. "
+                "Train pipeline initialized with `pipeline_postproc=False` (default), "
+                "so we assume KJT input modification. "
+                "To allow torchrec to check if this module can be safely pipelined, "
+                "please set `pipeline_postproc=True`"
             )
             return None
 
@@ -341,11 +345,10 @@ def _handle_module(
         )
         if num_found_safe_postproc_args == total_num_args:
             logger.info(
-                f"""Module {postproc_module} is a valid postproc module (no
-                trainable params and inputs can be derived from train batch input
-                via a series of either valid postproc modules or non-modifying
-                transformations) and will be applied during sparse data dist
-                stage"""
+                f"Module {postproc_module} is a valid postproc module (no "
+                "trainable params and inputs can be derived from train batch input "
+                "via a series of either valid postproc modules or non-modifying "
+                "transformations) and will be applied during sparse data dist stage"
             )
 
             pipelined_postproc_module = PipelinedPostproc(
@@ -449,6 +452,10 @@ def _get_node_args_helper_inner(
                 arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1]))
                 arg = child_node.args[0]
             else:
+                logger.warning(
+                    f"fx node {child_node.name, child_node.op, child_node.target} "
+                    "can't be handled correctly for postproc module"
+                )
                 break
 
     # if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found"