diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py
index c4a95756c..f107f784c 100644
--- a/torchrec/distributed/train_pipeline/tracing.py
+++ b/torchrec/distributed/train_pipeline/tracing.py
@@ -302,7 +302,11 @@ def _handle_module(
 
         if not self._pipeline_postproc:
             logger.warning(
-                f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`"
+                f"Found module {postproc_module} that potentially modifies input KJT. "
+                "Train pipeline initialized with `pipeline_postproc=False` (default), "
+                "so we assume KJT input modification. "
+                "To allow torchrec to check if this module can be safely pipelined, "
+                "please set `pipeline_postproc=True`"
             )
             return None
 
@@ -341,11 +345,10 @@ def _handle_module(
         )
         if num_found_safe_postproc_args == total_num_args:
             logger.info(
-                f"""Module {postproc_module} is a valid postproc module (no
-                trainable params and inputs can be derived from train batch input
-                via a series of either valid postproc modules or non-modifying
-                transformations) and will be applied during sparse data dist
-                stage"""
+                f"Module {postproc_module} is a valid postproc module (no "
+                "trainable params and inputs can be derived from train batch input "
+                "via a series of either valid postproc modules or non-modifying "
+                "transformations) and will be applied during sparse data dist stage"
             )
 
             pipelined_postproc_module = PipelinedPostproc(
@@ -449,6 +452,10 @@ def _get_node_args_helper_inner(
                 arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1]))
                 arg = child_node.args[0]
             else:
+                logger.warning(
+                    f"fx node {child_node.name, child_node.op, child_node.target} "
+                    "can't be handled correctly for postproc module"
+                )
                 break
 
     # if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found"